diff --git a/internal/fakechain/fakechain.go b/internal/fakechain/fakechain.go index 0054d246b..73d0d9746 100644 --- a/internal/fakechain/fakechain.go +++ b/internal/fakechain/fakechain.go @@ -327,12 +327,12 @@ func (chain *FakeChain) GetStorageItem(id int32, key []byte) state.StorageItem { } // GetTestVM implements Blockchainer interface. -func (chain *FakeChain) GetTestVM(t trigger.Type, tx *transaction.Transaction, b *block.Block) *vm.VM { +func (chain *FakeChain) GetTestVM(t trigger.Type, tx *transaction.Transaction, b *block.Block) (*vm.VM, func()) { panic("TODO") } // GetStorageItems implements Blockchainer interface. -func (chain *FakeChain) GetStorageItems(id int32) (map[string]state.StorageItem, error) { +func (chain *FakeChain) GetStorageItems(id int32) ([]state.StorageItemWithKey, error) { panic("TODO") } diff --git a/pkg/core/blockchain.go b/pkg/core/blockchain.go index a5d9aeecd..77543467e 100644 --- a/pkg/core/blockchain.go +++ b/pkg/core/blockchain.go @@ -490,9 +490,8 @@ func (bc *Blockchain) jumpToStateInternal(p uint32, stage stateJumpStage) error // Firstly, remove all old genesis-related items. b := bc.dao.Store.Batch() bc.dao.Store.Seek([]byte{byte(storage.STStorage)}, func(k, _ []byte) { - // Must copy here, #1468. - key := slice.Copy(k) - b.Delete(key) + // #1468, but don't need to copy here, because it is done by Store. + b.Delete(k) }) b.Put(jumpStageKey, []byte{byte(oldStorageItemsRemoved)}) err := bc.dao.Store.PutBatch(b) @@ -509,14 +508,12 @@ func (bc *Blockchain) jumpToStateInternal(p uint32, stage stateJumpStage) error if count >= maxStorageBatchSize { return } - // Must copy here, #1468. - oldKey := slice.Copy(k) - b.Delete(oldKey) + // #1468, but don't need to copy here, because it is done by Store. + b.Delete(k) key := make([]byte, len(k)) key[0] = byte(storage.STStorage) copy(key[1:], k[1:]) - value := slice.Copy(v) - b.Put(key, value) + b.Put(key, slice.Copy(v)) count += 2 }) if count > 0 { @@ -1039,7 +1036,7 @@ func (bc *Blockchain) storeBlock(block *block.Block, txpool *mempool.Pool) error v.LoadToken = contract.LoadToken(systemInterop) v.GasLimit = tx.SystemFee - err := v.Run() + err := systemInterop.Exec() var faultException string if !v.HasFailed() { _, err := systemInterop.DAO.Persist() @@ -1226,7 +1223,7 @@ func (bc *Blockchain) runPersist(script []byte, block *block.Block, cache dao.DA v := systemInterop.SpawnVM() v.LoadScriptWithFlags(script, callflag.All) v.SetPriceGetter(systemInterop.GetPrice) - if err := v.Run(); err != nil { + if err := systemInterop.Exec(); err != nil { return nil, fmt.Errorf("VM has failed: %w", err) } else if _, err := systemInterop.DAO.Persist(); err != nil { return nil, fmt.Errorf("can't save changes: %w", err) @@ -1494,7 +1491,7 @@ func (bc *Blockchain) GetStorageItem(id int32, key []byte) state.StorageItem { } // GetStorageItems returns all storage items for a given contract id. -func (bc *Blockchain) GetStorageItems(id int32) (map[string]state.StorageItem, error) { +func (bc *Blockchain) GetStorageItems(id int32) ([]state.StorageItemWithKey, error) { return bc.dao.GetStorageItems(id) } @@ -2055,14 +2052,14 @@ func (bc *Blockchain) GetEnrollments() ([]state.Validator, error) { return bc.contracts.NEO.GetCandidates(bc.dao) } -// GetTestVM returns a VM and a Store setup for a test run of some sort of code. -func (bc *Blockchain) GetTestVM(t trigger.Type, tx *transaction.Transaction, b *block.Block) *vm.VM { +// GetTestVM returns a VM setup for a test run of some sort of code and finalizer function. +func (bc *Blockchain) GetTestVM(t trigger.Type, tx *transaction.Transaction, b *block.Block) (*vm.VM, func()) { d := bc.dao.GetWrapped().(*dao.Simple) systemInterop := bc.newInteropContext(t, d, b, tx) vm := systemInterop.SpawnVM() vm.SetPriceGetter(systemInterop.GetPrice) vm.LoadToken = contract.LoadToken(systemInterop) - return vm + return vm, systemInterop.Finalize } // Various witness verification errors. @@ -2141,7 +2138,7 @@ func (bc *Blockchain) verifyHashAgainstScript(hash util.Uint160, witness *transa if err := bc.InitVerificationVM(vm, interopCtx.GetContract, hash, witness); err != nil { return 0, err } - err := vm.Run() + err := interopCtx.Exec() if vm.HasFailed() { return 0, fmt.Errorf("%w: vm execution has failed: %v", ErrVerificationFailed, err) } diff --git a/pkg/core/blockchainer/blockchainer.go b/pkg/core/blockchainer/blockchainer.go index 97bd64611..b63bca04c 100644 --- a/pkg/core/blockchainer/blockchainer.go +++ b/pkg/core/blockchainer/blockchainer.go @@ -58,8 +58,8 @@ type Blockchainer interface { GetStateModule() StateRoot GetStateSyncModule() StateSync GetStorageItem(id int32, key []byte) state.StorageItem - GetStorageItems(id int32) (map[string]state.StorageItem, error) - GetTestVM(t trigger.Type, tx *transaction.Transaction, b *block.Block) *vm.VM + GetStorageItems(id int32) ([]state.StorageItemWithKey, error) + GetTestVM(t trigger.Type, tx *transaction.Transaction, b *block.Block) (*vm.VM, func()) GetTransaction(util.Uint256) (*transaction.Transaction, uint32, error) SetOracle(service services.Oracle) mempool.Feer // fee interface diff --git a/pkg/core/dao/dao.go b/pkg/core/dao/dao.go index 2a509344f..b15a36e97 100644 --- a/pkg/core/dao/dao.go +++ b/pkg/core/dao/dao.go @@ -2,6 +2,7 @@ package dao import ( "bytes" + "context" "encoding/binary" "errors" "fmt" @@ -16,7 +17,6 @@ import ( "github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger" "github.com/nspcc-dev/neo-go/pkg/util" - "github.com/nspcc-dev/neo-go/pkg/util/slice" ) // HasTransaction errors. @@ -48,8 +48,8 @@ type DAO interface { GetStateSyncPoint() (uint32, error) GetStateSyncCurrentBlockHeight() (uint32, error) GetStorageItem(id int32, key []byte) state.StorageItem - GetStorageItems(id int32) (map[string]state.StorageItem, error) - GetStorageItemsWithPrefix(id int32, prefix []byte) (map[string]state.StorageItem, error) + GetStorageItems(id int32) ([]state.StorageItemWithKey, error) + GetStorageItemsWithPrefix(id int32, prefix []byte) ([]state.StorageItemWithKey, error) GetTransaction(hash util.Uint256) (*transaction.Transaction, uint32, error) GetVersion() (string, error) GetWrapped() DAO @@ -65,6 +65,7 @@ type DAO interface { PutStorageItem(id int32, key []byte, si state.StorageItem) error PutVersion(v string) error Seek(id int32, prefix []byte, f func(k, v []byte)) + SeekAsync(ctx context.Context, id int32, prefix []byte) chan storage.KeyValue StoreAsBlock(block *block.Block, buf *io.BufBinWriter) error StoreAsCurrentBlock(block *block.Block, buf *io.BufBinWriter) error StoreAsTransaction(tx *transaction.Transaction, index uint32, buf *io.BufBinWriter) error @@ -313,28 +314,29 @@ func (dao *Simple) DeleteStorageItem(id int32, key []byte) error { } // GetStorageItems returns all storage items for a given id. -func (dao *Simple) GetStorageItems(id int32) (map[string]state.StorageItem, error) { +func (dao *Simple) GetStorageItems(id int32) ([]state.StorageItemWithKey, error) { return dao.GetStorageItemsWithPrefix(id, nil) } // GetStorageItemsWithPrefix returns all storage items with given id for a // given scripthash. -func (dao *Simple) GetStorageItemsWithPrefix(id int32, prefix []byte) (map[string]state.StorageItem, error) { - var siMap = make(map[string]state.StorageItem) +func (dao *Simple) GetStorageItemsWithPrefix(id int32, prefix []byte) ([]state.StorageItemWithKey, error) { + var siArr []state.StorageItemWithKey - saveToMap := func(k, v []byte) { + saveToArr := func(k, v []byte) { // Cut prefix and hash. - // Must copy here, #1468. - key := slice.Copy(k) - val := slice.Copy(v) - siMap[string(key)] = state.StorageItem(val) + // #1468, but don't need to copy here, because it is done by Store. + siArr = append(siArr, state.StorageItemWithKey{ + Key: k, + Item: state.StorageItem(v), + }) } - dao.Seek(id, prefix, saveToMap) - return siMap, nil + dao.Seek(id, prefix, saveToArr) + return siArr, nil } // Seek executes f for all items with a given prefix. -// If key is to be used outside of f, they must be copied. +// If key is to be used outside of f, they may not be copied. func (dao *Simple) Seek(id int32, prefix []byte, f func(k, v []byte)) { lookupKey := makeStorageItemKey(id, nil) if prefix != nil { @@ -345,6 +347,16 @@ func (dao *Simple) Seek(id int32, prefix []byte, f func(k, v []byte)) { }) } +// SeekAsync sends all storage items matching given prefix to a channel and returns +// the channel. Resulting keys and values may not be copied. +func (dao *Simple) SeekAsync(ctx context.Context, id int32, prefix []byte) chan storage.KeyValue { + lookupKey := makeStorageItemKey(id, nil) + if prefix != nil { + lookupKey = append(lookupKey, prefix...) + } + return dao.Store.SeekAsync(ctx, lookupKey, true) +} + // makeStorageItemKey returns a key used to store StorageItem in the DB. func makeStorageItemKey(id int32, key []byte) []byte { // 1 for prefix + 4 for Uint32 + len(key) for key diff --git a/pkg/core/interop/context.go b/pkg/core/interop/context.go index 908b8a6ac..1c7a3ad6e 100644 --- a/pkg/core/interop/context.go +++ b/pkg/core/interop/context.go @@ -1,6 +1,7 @@ package interop import ( + "context" "encoding/binary" "errors" "fmt" @@ -47,6 +48,7 @@ type Context struct { Log *zap.Logger VM *vm.VM Functions []Function + cancelFuncs []context.CancelFunc getContract func(dao.DAO, util.Uint160) (*state.Contract, error) baseExecFee int64 } @@ -285,3 +287,25 @@ func (ic *Context) SpawnVM() *vm.VM { ic.VM = v return v } + +// RegisterCancelFunc adds given function to the list of functions to be called after VM +// finishes script execution. +func (ic *Context) RegisterCancelFunc(f context.CancelFunc) { + if f != nil { + ic.cancelFuncs = append(ic.cancelFuncs, f) + } +} + +// Finalize calls all registered cancel functions to release the occupied resources. +func (ic *Context) Finalize() { + for _, f := range ic.cancelFuncs { + f() + } + ic.cancelFuncs = nil +} + +// Exec executes loaded VM script and calls registered finalizers to release the occupied resources. +func (ic *Context) Exec() error { + defer ic.Finalize() + return ic.VM.Run() +} diff --git a/pkg/core/interop/storage/find.go b/pkg/core/interop/storage/find.go index 045c9290f..35b6621f4 100644 --- a/pkg/core/interop/storage/find.go +++ b/pkg/core/interop/storage/find.go @@ -1,6 +1,10 @@ package storage -import "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" +import ( + "github.com/nspcc-dev/neo-go/pkg/core/storage" + "github.com/nspcc-dev/neo-go/pkg/util/slice" + "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" +) // Storage iterator options. const ( @@ -18,44 +22,45 @@ const ( // Iterator is an iterator state representation. type Iterator struct { - m []stackitem.MapElement - opts int64 - index int - prefixSize int + seekCh chan storage.KeyValue + curr storage.KeyValue + next bool + opts int64 + prefix []byte } -// NewIterator creates a new Iterator with given options for a given map. -func NewIterator(m *stackitem.Map, prefix int, opts int64) *Iterator { +// NewIterator creates a new Iterator with given options for a given channel of store.Seek results. +func NewIterator(seekCh chan storage.KeyValue, prefix []byte, opts int64) *Iterator { return &Iterator{ - m: m.Value().([]stackitem.MapElement), - opts: opts, - index: -1, - prefixSize: prefix, + seekCh: seekCh, + opts: opts, + prefix: slice.Copy(prefix), } } // Next advances the iterator and returns true if Value can be called at the // current position. func (s *Iterator) Next() bool { - if s.index < len(s.m) { - s.index++ - } - return s.index < len(s.m) + s.curr, s.next = <-s.seekCh + return s.next } // Value returns current iterators value (exact type depends on options this // iterator was created with). func (s *Iterator) Value() stackitem.Item { - key := s.m[s.index].Key.Value().([]byte) - if s.opts&FindRemovePrefix != 0 { - key = key[s.prefixSize:] + if !s.next { + panic("iterator index out of range") + } + key := s.curr.Key + if s.opts&FindRemovePrefix == 0 { + key = append(s.prefix, key...) } if s.opts&FindKeysOnly != 0 { return stackitem.NewByteArray(key) } - value := s.m[s.index].Value + value := stackitem.Item(stackitem.NewByteArray(s.curr.Value)) if s.opts&FindDeserialize != 0 { - bs := s.m[s.index].Value.Value().([]byte) + bs := s.curr.Value var err error value, err = stackitem.Deserialize(bs) if err != nil { diff --git a/pkg/core/interop_system.go b/pkg/core/interop_system.go index a56ef18cd..fd473b65b 100644 --- a/pkg/core/interop_system.go +++ b/pkg/core/interop_system.go @@ -1,13 +1,12 @@ package core import ( - "bytes" + "context" "crypto/elliptic" "errors" "fmt" "math" "math/big" - "sort" "github.com/nspcc-dev/neo-go/pkg/core/block" "github.com/nspcc-dev/neo-go/pkg/core/interop" @@ -188,30 +187,13 @@ func storageFind(ic *interop.Context) error { if opts&istorage.FindDeserialize == 0 && (opts&istorage.FindPick0 != 0 || opts&istorage.FindPick1 != 0) { return fmt.Errorf("%w: PickN is specified without Deserialize", errFindInvalidOptions) } - siMap, err := ic.DAO.GetStorageItemsWithPrefix(stc.ID, prefix) - if err != nil { - return err - } - - arr := make([]stackitem.MapElement, 0, len(siMap)) - for k, v := range siMap { - keycopy := make([]byte, len(k)+len(prefix)) - copy(keycopy, prefix) - copy(keycopy[len(prefix):], k) - arr = append(arr, stackitem.MapElement{ - Key: stackitem.NewByteArray(keycopy), - Value: stackitem.NewByteArray(v), - }) - } - sort.Slice(arr, func(i, j int) bool { - k1 := arr[i].Key.Value().([]byte) - k2 := arr[j].Key.Value().([]byte) - return bytes.Compare(k1, k2) == -1 - }) - - filteredMap := stackitem.NewMapWithValue(arr) - item := istorage.NewIterator(filteredMap, len(prefix), opts) + // Items in seekres should be sorted by key, but GetStorageItemsWithPrefix returns + // sorted items, so no need to sort them one more time. + ctx, cancel := context.WithCancel(context.Background()) + seekres := ic.DAO.SeekAsync(ctx, stc.ID, prefix) + item := istorage.NewIterator(seekres, prefix, opts) ic.VM.Estack().PushItem(stackitem.NewInterop(item)) + ic.RegisterCancelFunc(cancel) return nil } diff --git a/pkg/core/interop_system_test.go b/pkg/core/interop_system_test.go index 809f38b12..f86b73858 100644 --- a/pkg/core/interop_system_test.go +++ b/pkg/core/interop_system_test.go @@ -2,6 +2,7 @@ package core import ( "errors" + "fmt" "math" "math/big" "testing" @@ -336,32 +337,101 @@ func TestStorageDelete(t *testing.T) { } func BenchmarkStorageFind(b *testing.B) { - v, contractState, context, chain := createVMAndContractState(b) - require.NoError(b, chain.contracts.Management.PutContractState(chain.dao, contractState)) + for count := 10; count <= 10000; count *= 10 { + b.Run(fmt.Sprintf("%dElements", count), func(b *testing.B) { + v, contractState, context, chain := createVMAndContractState(b) + require.NoError(b, chain.contracts.Management.PutContractState(chain.dao, contractState)) - const count = 100 + items := make(map[string]state.StorageItem) + for i := 0; i < count; i++ { + items["abc"+random.String(10)] = random.Bytes(10) + } + for k, v := range items { + require.NoError(b, context.DAO.PutStorageItem(contractState.ID, []byte(k), v)) + require.NoError(b, context.DAO.PutStorageItem(contractState.ID+1, []byte(k), v)) + } + changes, err := context.DAO.Persist() + require.NoError(b, err) + require.NotEqual(b, 0, changes) - items := make(map[string]state.StorageItem) - for i := 0; i < count; i++ { - items["abc"+random.String(10)] = random.Bytes(10) - } - for k, v := range items { - require.NoError(b, context.DAO.PutStorageItem(contractState.ID, []byte(k), v)) - require.NoError(b, context.DAO.PutStorageItem(contractState.ID+1, []byte(k), v)) + b.ResetTimer() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + b.StopTimer() + v.Estack().PushVal(istorage.FindDefault) + v.Estack().PushVal("abc") + v.Estack().PushVal(stackitem.NewInterop(&StorageContext{ID: contractState.ID})) + b.StartTimer() + err := storageFind(context) + if err != nil { + b.FailNow() + } + b.StopTimer() + context.Finalize() + } + }) } +} - b.ResetTimer() - b.ReportAllocs() - for i := 0; i < b.N; i++ { - b.StopTimer() - v.Estack().PushVal(istorage.FindDefault) - v.Estack().PushVal("abc") - v.Estack().PushVal(stackitem.NewInterop(&StorageContext{ID: contractState.ID})) - b.StartTimer() - err := storageFind(context) - if err != nil { - b.FailNow() +func BenchmarkStorageFindIteratorNext(b *testing.B) { + for count := 10; count <= 10000; count *= 10 { + cases := map[string]int{ + "Pick1": 1, + "PickHalf": count / 2, + "PickAll": count, } + b.Run(fmt.Sprintf("%dElements", count), func(b *testing.B) { + for name, last := range cases { + b.Run(name, func(b *testing.B) { + v, contractState, context, chain := createVMAndContractState(b) + require.NoError(b, chain.contracts.Management.PutContractState(chain.dao, contractState)) + + items := make(map[string]state.StorageItem) + for i := 0; i < count; i++ { + items["abc"+random.String(10)] = random.Bytes(10) + } + for k, v := range items { + require.NoError(b, context.DAO.PutStorageItem(contractState.ID, []byte(k), v)) + require.NoError(b, context.DAO.PutStorageItem(contractState.ID+1, []byte(k), v)) + } + changes, err := context.DAO.Persist() + require.NoError(b, err) + require.NotEqual(b, 0, changes) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + b.StopTimer() + v.Estack().PushVal(istorage.FindDefault) + v.Estack().PushVal("abc") + v.Estack().PushVal(stackitem.NewInterop(&StorageContext{ID: contractState.ID})) + b.StartTimer() + err := storageFind(context) + b.StopTimer() + if err != nil { + b.FailNow() + } + res := context.VM.Estack().Pop().Item() + for i := 0; i < last; i++ { + context.VM.Estack().PushVal(res) + b.StartTimer() + require.NoError(b, iterator.Next(context)) + b.StopTimer() + require.True(b, context.VM.Estack().Pop().Bool()) + } + + context.VM.Estack().PushVal(res) + require.NoError(b, iterator.Next(context)) + actual := context.VM.Estack().Pop().Bool() + if last == count { + require.False(b, actual) + } else { + require.True(b, actual) + } + context.Finalize() + } + }) + } + }) } } diff --git a/pkg/core/native/designate.go b/pkg/core/native/designate.go index 844690a3d..ccf5767b8 100644 --- a/pkg/core/native/designate.go +++ b/pkg/core/native/designate.go @@ -257,17 +257,22 @@ func (s *Designate) GetDesignatedByRole(d dao.DAO, r noderoles.Role, index uint3 if err != nil { return nil, 0, err } - var ns NodeList - var bestIndex uint32 - var resSi state.StorageItem - for k, si := range kvs { - if len(k) < 4 { + var ( + ns NodeList + bestIndex uint32 + resSi state.StorageItem + ) + // kvs are sorted by key (BE index bytes) in ascending way, so iterate backwards to get the latest designated. + for i := len(kvs) - 1; i >= 0; i-- { + kv := kvs[i] + if len(kv.Key) < 4 { continue } - siInd := binary.BigEndian.Uint32([]byte(k)) - if (resSi == nil || siInd > bestIndex) && siInd <= index { + siInd := binary.BigEndian.Uint32(kv.Key) + if siInd <= index { bestIndex = siInd - resSi = si + resSi = kv.Item + break } } if resSi != nil { diff --git a/pkg/core/native/management.go b/pkg/core/native/management.go index c66e9dc8d..7403559e7 100644 --- a/pkg/core/native/management.go +++ b/pkg/core/native/management.go @@ -391,12 +391,12 @@ func (m *Management) Destroy(d dao.DAO, hash util.Uint160) error { if err != nil { return err } - siMap, err := d.GetStorageItems(contract.ID) + siArr, err := d.GetStorageItems(contract.ID) if err != nil { return err } - for k := range siMap { - err := d.DeleteStorageItem(contract.ID, []byte(k)) + for _, kv := range siArr { + err := d.DeleteStorageItem(contract.ID, []byte(kv.Key)) if err != nil { return err } diff --git a/pkg/core/native/native_neo.go b/pkg/core/native/native_neo.go index 9df6e468b..53604bca3 100644 --- a/pkg/core/native/native_neo.go +++ b/pkg/core/native/native_neo.go @@ -471,24 +471,20 @@ func (n *NEO) getGASPerBlock(ic *interop.Context, _ []stackitem.Item) stackitem. } func (n *NEO) getSortedGASRecordFromDAO(d dao.DAO) (gasRecord, error) { - grMap, err := d.GetStorageItemsWithPrefix(n.ID, []byte{prefixGASPerBlock}) + grArr, err := d.GetStorageItemsWithPrefix(n.ID, []byte{prefixGASPerBlock}) if err != nil { return gasRecord{}, fmt.Errorf("failed to get gas records from storage: %w", err) } - var ( - i int - gr = make(gasRecord, len(grMap)) - ) - for indexBytes, gasValue := range grMap { + var gr = make(gasRecord, len(grArr)) + for i, kv := range grArr { + indexBytes, gasValue := kv.Key, kv.Item gr[i] = gasIndexPair{ Index: binary.BigEndian.Uint32([]byte(indexBytes)), GASPerBlock: *bigint.FromBytes(gasValue), } - i++ } - sort.Slice(gr, func(i, j int) bool { - return gr[i].Index < gr[j].Index - }) + // GAS records should be sorted by index, but GetStorageItemsWithPrefix returns + // values sorted by BE bytes of index, so we're OK with that. return gr, nil } @@ -836,21 +832,21 @@ func (n *NEO) ModifyAccountVotes(acc *state.NEOBalance, d dao.DAO, value *big.In } func (n *NEO) getCandidates(d dao.DAO, sortByKey bool) ([]keyWithVotes, error) { - siMap, err := d.GetStorageItemsWithPrefix(n.ID, []byte{prefixCandidate}) + siArr, err := d.GetStorageItemsWithPrefix(n.ID, []byte{prefixCandidate}) if err != nil { return nil, err } - arr := make([]keyWithVotes, 0, len(siMap)) - for key, si := range siMap { - c := new(candidate).FromBytes(si) + arr := make([]keyWithVotes, 0, len(siArr)) + for _, kv := range siArr { + c := new(candidate).FromBytes(kv.Item) if c.Registered { - arr = append(arr, keyWithVotes{Key: key, Votes: &c.Votes}) + arr = append(arr, keyWithVotes{Key: string(kv.Key), Votes: &c.Votes}) } } - if sortByKey { - // Sort by serialized key bytes (that's the way keys are stored and retrieved from the storage by default). - sort.Slice(arr, func(i, j int) bool { return strings.Compare(arr[i].Key, arr[j].Key) == -1 }) - } else { + if !sortByKey { + // sortByKey assumes to sort by serialized key bytes (that's the way keys + // are stored and retrieved from the storage by default). Otherwise, need + // to sort using big.Int comparator. sort.Slice(arr, func(i, j int) bool { // The most-voted validators should end up in the front of the list. cmp := arr[i].Votes.Cmp(arr[j].Votes) diff --git a/pkg/core/native/oracle.go b/pkg/core/native/oracle.go index c060a9739..393c0d979 100644 --- a/pkg/core/native/oracle.go +++ b/pkg/core/native/oracle.go @@ -483,21 +483,21 @@ func (o *Oracle) getOriginalTxID(d dao.DAO, tx *transaction.Transaction) util.Ui // getRequests returns all requests which have not been finished yet. func (o *Oracle) getRequests(d dao.DAO) (map[uint64]*state.OracleRequest, error) { - m, err := d.GetStorageItemsWithPrefix(o.ID, prefixRequest) + arr, err := d.GetStorageItemsWithPrefix(o.ID, prefixRequest) if err != nil { return nil, err } - reqs := make(map[uint64]*state.OracleRequest, len(m)) - for k, si := range m { - if len(k) != 8 { + reqs := make(map[uint64]*state.OracleRequest, len(arr)) + for _, kv := range arr { + if len(kv.Key) != 8 { return nil, errors.New("invalid request ID") } req := new(state.OracleRequest) - err = stackitem.DeserializeConvertible(si, req) + err = stackitem.DeserializeConvertible(kv.Item, req) if err != nil { return nil, err } - id := binary.BigEndian.Uint64([]byte(k)) + id := binary.BigEndian.Uint64([]byte(kv.Key)) reqs[id] = req } return reqs, nil diff --git a/pkg/core/native/policy.go b/pkg/core/native/policy.go index e377f0447..dffe1371c 100644 --- a/pkg/core/native/policy.go +++ b/pkg/core/native/policy.go @@ -162,20 +162,20 @@ func (p *Policy) PostPersist(ic *interop.Context) error { p.storagePrice = uint32(getIntWithKey(p.ID, ic.DAO, storagePriceKey)) p.blockedAccounts = make([]util.Uint160, 0) - siMap, err := ic.DAO.GetStorageItemsWithPrefix(p.ID, []byte{blockedAccountPrefix}) + siArr, err := ic.DAO.GetStorageItemsWithPrefix(p.ID, []byte{blockedAccountPrefix}) if err != nil { return fmt.Errorf("failed to get blocked accounts from storage: %w", err) } - for key := range siMap { - hash, err := util.Uint160DecodeBytesBE([]byte(key)) + for _, kv := range siArr { + hash, err := util.Uint160DecodeBytesBE([]byte(kv.Key)) if err != nil { return fmt.Errorf("failed to decode blocked account hash: %w", err) } p.blockedAccounts = append(p.blockedAccounts, hash) } - sort.Slice(p.blockedAccounts, func(i, j int) bool { - return p.blockedAccounts[i].Less(p.blockedAccounts[j]) - }) + // blockedAccounts should be sorted by account BE bytes, but GetStorageItemsWithPrefix + // returns values sorted by key (which is account's BE bytes), so don't need to sort + // one more time. p.isValid = true return nil diff --git a/pkg/core/state/storage_item.go b/pkg/core/state/storage_item.go index ba9777cf8..0a6eebb23 100644 --- a/pkg/core/state/storage_item.go +++ b/pkg/core/state/storage_item.go @@ -2,3 +2,9 @@ package state // StorageItem is the value to be stored with read-only flag. type StorageItem []byte + +// StorageItemWithKey is a storage item with corresponding key. +type StorageItemWithKey struct { + Key []byte + Item StorageItem +} diff --git a/pkg/core/stateroot/module.go b/pkg/core/stateroot/module.go index a96bac1f1..fdc053595 100644 --- a/pkg/core/stateroot/module.go +++ b/pkg/core/stateroot/module.go @@ -13,7 +13,6 @@ import ( "github.com/nspcc-dev/neo-go/pkg/core/storage" "github.com/nspcc-dev/neo-go/pkg/crypto/keys" "github.com/nspcc-dev/neo-go/pkg/util" - "github.com/nspcc-dev/neo-go/pkg/util/slice" "go.uber.org/atomic" "go.uber.org/zap" ) @@ -147,9 +146,8 @@ func (s *Module) CleanStorage() error { // b := s.Store.Batch() s.Store.Seek([]byte{byte(storage.DataMPT)}, func(k, _ []byte) { - // Must copy here, #1468. - key := slice.Copy(k) - b.Delete(key) + // #1468, but don't need to copy here, because it is done by Store. + b.Delete(k) }) err = s.Store.PutBatch(b) if err != nil { diff --git a/pkg/core/storage/badgerdb_store_test.go b/pkg/core/storage/badgerdb_store_test.go index e00fa4a15..78a8fb957 100644 --- a/pkg/core/storage/badgerdb_store_test.go +++ b/pkg/core/storage/badgerdb_store_test.go @@ -6,7 +6,7 @@ import ( "github.com/stretchr/testify/require" ) -func newBadgerDBForTesting(t *testing.T) Store { +func newBadgerDBForTesting(t testing.TB) Store { bdbDir := t.TempDir() dbConfig := DBConfiguration{ Type: "badgerdb", diff --git a/pkg/core/storage/boltdb_store_test.go b/pkg/core/storage/boltdb_store_test.go index 9f8fd4caf..cd1a84e39 100644 --- a/pkg/core/storage/boltdb_store_test.go +++ b/pkg/core/storage/boltdb_store_test.go @@ -7,7 +7,7 @@ import ( "github.com/stretchr/testify/require" ) -func newBoltStoreForTesting(t *testing.T) Store { +func newBoltStoreForTesting(t testing.TB) Store { d := t.TempDir() testFileName := path.Join(d, "test_bolt_db") boltDBStore, err := NewBoltDBStore(BoltDBOptions{FilePath: testFileName}) diff --git a/pkg/core/storage/leveldb_store_test.go b/pkg/core/storage/leveldb_store_test.go index 5d8672e7a..24f21e76b 100644 --- a/pkg/core/storage/leveldb_store_test.go +++ b/pkg/core/storage/leveldb_store_test.go @@ -6,7 +6,7 @@ import ( "github.com/stretchr/testify/require" ) -func newLevelDBForTesting(t *testing.T) Store { +func newLevelDBForTesting(t testing.TB) Store { ldbDir := t.TempDir() dbConfig := DBConfiguration{ Type: "leveldb", diff --git a/pkg/core/storage/memcached_store.go b/pkg/core/storage/memcached_store.go index b21d3edaa..afe32481b 100644 --- a/pkg/core/storage/memcached_store.go +++ b/pkg/core/storage/memcached_store.go @@ -1,6 +1,14 @@ package storage -import "sync" +import ( + "bytes" + "context" + "sort" + "strings" + "sync" + + "github.com/nspcc-dev/neo-go/pkg/util/slice" +) // MemCachedStore is a wrapper around persistent store that caches all changes // being made for them to be later flushed in one batch. @@ -18,14 +26,20 @@ type ( KeyValue struct { Key []byte Value []byte + } + + // KeyValueExists represents key-value pair with indicator whether the item + // exists in the persistent storage. + KeyValueExists struct { + KeyValue Exists bool } // MemBatch represents a changeset to be persisted. MemBatch struct { - Put []KeyValue - Deleted []KeyValue + Put []KeyValueExists + Deleted []KeyValueExists } ) @@ -58,18 +72,18 @@ func (s *MemCachedStore) GetBatch() *MemBatch { var b MemBatch - b.Put = make([]KeyValue, 0, len(s.mem)) + b.Put = make([]KeyValueExists, 0, len(s.mem)) for k, v := range s.mem { key := []byte(k) _, err := s.ps.Get(key) - b.Put = append(b.Put, KeyValue{Key: key, Value: v, Exists: err == nil}) + b.Put = append(b.Put, KeyValueExists{KeyValue: KeyValue{Key: key, Value: v}, Exists: err == nil}) } - b.Deleted = make([]KeyValue, 0, len(s.del)) + b.Deleted = make([]KeyValueExists, 0, len(s.del)) for k := range s.del { key := []byte(k) _, err := s.ps.Get(key) - b.Deleted = append(b.Deleted, KeyValue{Key: key, Exists: err == nil}) + b.Deleted = append(b.Deleted, KeyValueExists{KeyValue: KeyValue{Key: key}, Exists: err == nil}) } return &b @@ -77,21 +91,130 @@ func (s *MemCachedStore) GetBatch() *MemBatch { // Seek implements the Store interface. func (s *MemCachedStore) Seek(key []byte, f func(k, v []byte)) { + s.seek(context.Background(), key, false, f) +} + +// SeekAsync returns non-buffered channel with matching KeyValue pairs. Key and +// value slices may not be copied and may be modified. SeekAsync can guarantee +// that key-value items are sorted by key in ascending way. +func (s *MemCachedStore) SeekAsync(ctx context.Context, key []byte, cutPrefix bool) chan KeyValue { + res := make(chan KeyValue) + go func() { + s.seek(ctx, key, cutPrefix, func(k, v []byte) { + res <- KeyValue{ + Key: k, + Value: v, + } + }) + close(res) + }() + + return res +} + +func (s *MemCachedStore) seek(ctx context.Context, key []byte, cutPrefix bool, f func(k, v []byte)) { + // Create memory store `mem` and `del` snapshot not to hold the lock. + var memRes []KeyValueExists + sk := string(key) s.mut.RLock() - defer s.mut.RUnlock() - s.MemoryStore.seek(key, f) - s.ps.Seek(key, func(k, v []byte) { - elem := string(k) - // If it's in mem, we already called f() for it in MemoryStore.Seek(). - _, present := s.mem[elem] - if !present { - // If it's in del, we shouldn't be calling f() anyway. - _, present = s.del[elem] + for k, v := range s.MemoryStore.mem { + if strings.HasPrefix(k, sk) { + memRes = append(memRes, KeyValueExists{ + KeyValue: KeyValue{ + Key: []byte(k), + Value: v, + }, + Exists: true, + }) } - if !present { - f(k, v) + } + for k := range s.MemoryStore.del { + if strings.HasPrefix(k, sk) { + memRes = append(memRes, KeyValueExists{ + KeyValue: KeyValue{ + Key: []byte(k), + }, + }) + } + } + ps := s.ps + s.mut.RUnlock() + // Sort memRes items for further comparison with ps items. + sort.Slice(memRes, func(i, j int) bool { + return bytes.Compare(memRes[i].Key, memRes[j].Key) < 0 + }) + + var ( + done bool + iMem int + kvMem KeyValueExists + haveMem bool + ) + if iMem < len(memRes) { + kvMem = memRes[iMem] + haveMem = true + iMem++ + } + // Merge results of seek operations in ascending order. + ps.Seek(key, func(k, v []byte) { + if done { + return + } + kvPs := KeyValue{ + Key: slice.Copy(k), + Value: slice.Copy(v), + } + loop: + for { + select { + case <-ctx.Done(): + done = true + break loop + default: + var isMem = haveMem && (bytes.Compare(kvMem.Key, kvPs.Key) < 0) + if isMem { + if kvMem.Exists { + if cutPrefix { + kvMem.Key = kvMem.Key[len(key):] + } + f(kvMem.Key, kvMem.Value) + } + if iMem < len(memRes) { + kvMem = memRes[iMem] + haveMem = true + iMem++ + } else { + haveMem = false + } + } else { + if !bytes.Equal(kvMem.Key, kvPs.Key) { + if cutPrefix { + kvPs.Key = kvPs.Key[len(key):] + } + f(kvPs.Key, kvPs.Value) + } + break loop + } + } } }) + if !done && haveMem { + loop: + for i := iMem - 1; i < len(memRes); i++ { + select { + case <-ctx.Done(): + break loop + default: + kvMem = memRes[i] + if kvMem.Exists { + if cutPrefix { + kvMem.Key = kvMem.Key[len(key):] + } + f(kvMem.Key, kvMem.Value) + } + } + } + } } // Persist flushes all the MemoryStore contents into the (supposedly) persistent diff --git a/pkg/core/storage/memcached_store_test.go b/pkg/core/storage/memcached_store_test.go index fe811f61d..216c6b8e7 100644 --- a/pkg/core/storage/memcached_store_test.go +++ b/pkg/core/storage/memcached_store_test.go @@ -1,8 +1,13 @@ package storage import ( + "bytes" + "fmt" + "sort" "testing" + "github.com/nspcc-dev/neo-go/internal/random" + "github.com/nspcc-dev/neo-go/pkg/util/slice" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -16,7 +21,7 @@ func testMemCachedStorePersist(t *testing.T, ps Store) { assert.Equal(t, 0, c) // persisting one key should result in one key in ps and nothing in ts assert.NoError(t, ts.Put([]byte("key"), []byte("value"))) - checkBatch(t, ts, []KeyValue{{Key: []byte("key"), Value: []byte("value")}}, nil) + checkBatch(t, ts, []KeyValueExists{{KeyValue: KeyValue{Key: []byte("key"), Value: []byte("value")}}}, nil) c, err = ts.Persist() checkBatch(t, ts, nil, nil) assert.Equal(t, nil, err) @@ -35,9 +40,9 @@ func testMemCachedStorePersist(t *testing.T, ps Store) { v, err = ps.Get([]byte("key2")) assert.Equal(t, ErrKeyNotFound, err) assert.Equal(t, []byte(nil), v) - checkBatch(t, ts, []KeyValue{ - {Key: []byte("key"), Value: []byte("newvalue"), Exists: true}, - {Key: []byte("key2"), Value: []byte("value2")}, + checkBatch(t, ts, []KeyValueExists{ + {KeyValue: KeyValue{Key: []byte("key"), Value: []byte("newvalue")}, Exists: true}, + {KeyValue: KeyValue{Key: []byte("key2"), Value: []byte("value2")}}, }, nil) // two keys should be persisted (one overwritten and one new) and // available in the ps @@ -65,7 +70,7 @@ func testMemCachedStorePersist(t *testing.T, ps Store) { // test persisting deletions err = ts.Delete([]byte("key")) assert.Equal(t, nil, err) - checkBatch(t, ts, nil, []KeyValue{{Key: []byte("key"), Exists: true}}) + checkBatch(t, ts, nil, []KeyValueExists{{KeyValue: KeyValue{Key: []byte("key")}, Exists: true}}) c, err = ts.Persist() checkBatch(t, ts, nil, nil) assert.Equal(t, nil, err) @@ -78,7 +83,7 @@ func testMemCachedStorePersist(t *testing.T, ps Store) { assert.Equal(t, []byte("value2"), v) } -func checkBatch(t *testing.T, ts *MemCachedStore, put []KeyValue, del []KeyValue) { +func checkBatch(t *testing.T, ts *MemCachedStore, put []KeyValueExists, del []KeyValueExists) { b := ts.GetBatch() assert.Equal(t, len(put), len(b.Put), "wrong number of put elements in a batch") assert.Equal(t, len(del), len(b.Deleted), "wrong number of deleted elements in a batch") @@ -174,7 +179,82 @@ func TestCachedSeek(t *testing.T) { } } -func newMemCachedStoreForTesting(t *testing.T) Store { +func benchmarkCachedSeek(t *testing.B, ps Store, psElementsCount, tsElementsCount int) { + var ( + searchPrefix = []byte{1} + badPrefix = []byte{2} + lowerPrefixGood = append(searchPrefix, 1) + lowerPrefixBad = append(badPrefix, 1) + deletedPrefixGood = append(searchPrefix, 2) + deletedPrefixBad = append(badPrefix, 2) + updatedPrefixGood = append(searchPrefix, 3) + updatedPrefixBad = append(badPrefix, 3) + + ts = NewMemCachedStore(ps) + ) + for i := 0; i < psElementsCount; i++ { + // lower KVs with matching prefix that should be found + require.NoError(t, ps.Put(append(lowerPrefixGood, random.Bytes(10)...), []byte("value"))) + // lower KVs with non-matching prefix that shouldn't be found + require.NoError(t, ps.Put(append(lowerPrefixBad, random.Bytes(10)...), []byte("value"))) + + // deleted KVs with matching prefix that shouldn't be found + key := append(deletedPrefixGood, random.Bytes(10)...) + require.NoError(t, ps.Put(key, []byte("deleted"))) + if i < tsElementsCount { + require.NoError(t, ts.Delete(key)) + } + // deleted KVs with non-matching prefix that shouldn't be found + key = append(deletedPrefixBad, random.Bytes(10)...) + require.NoError(t, ps.Put(key, []byte("deleted"))) + if i < tsElementsCount { + require.NoError(t, ts.Delete(key)) + } + + // updated KVs with matching prefix that should be found + key = append(updatedPrefixGood, random.Bytes(10)...) + require.NoError(t, ps.Put(key, []byte("stub"))) + if i < tsElementsCount { + require.NoError(t, ts.Put(key, []byte("updated"))) + } + // updated KVs with non-matching prefix that shouldn't be found + key = append(updatedPrefixBad, random.Bytes(10)...) + require.NoError(t, ps.Put(key, []byte("stub"))) + if i < tsElementsCount { + require.NoError(t, ts.Put(key, []byte("updated"))) + } + } + + t.ReportAllocs() + t.ResetTimer() + for n := 0; n < t.N; n++ { + ts.Seek(searchPrefix, func(k, v []byte) {}) + } + t.StopTimer() +} + +func BenchmarkCachedSeek(t *testing.B) { + var stores = map[string]func(testing.TB) Store{ + "MemPS": func(t testing.TB) Store { + return NewMemoryStore() + }, + "BoltPS": newBoltStoreForTesting, + "LevelPS": newLevelDBForTesting, + } + for psName, newPS := range stores { + for psCount := 100; psCount <= 10000; psCount *= 10 { + for tsCount := 10; tsCount <= psCount; tsCount *= 10 { + t.Run(fmt.Sprintf("%s_%dTSItems_%dPSItems", psName, tsCount, psCount), func(t *testing.B) { + ps := newPS(t) + benchmarkCachedSeek(t, ps, psCount, tsCount) + ps.Close() + }) + } + } + } +} + +func newMemCachedStoreForTesting(t testing.TB) Store { return NewMemCachedStore(NewMemoryStore()) } @@ -242,3 +322,52 @@ func TestMemCachedPersistFailing(t *testing.T) { require.NoError(t, err) require.Equal(t, b1, res) } + +func TestCachedSeekSorting(t *testing.T) { + var ( + // Given this prefix... + goodPrefix = []byte{1} + // these pairs should be found... + lowerKVs = []kvSeen{ + {[]byte{1, 2, 3}, []byte("bra"), false}, + {[]byte{1, 2, 5}, []byte("bar"), false}, + {[]byte{1, 3, 3}, []byte("bra"), false}, + {[]byte{1, 3, 5}, []byte("bra"), false}, + } + // and these should be not. + deletedKVs = []kvSeen{ + {[]byte{1, 7, 3}, []byte("pow"), false}, + {[]byte{1, 7, 4}, []byte("qaz"), false}, + } + // and these should be not. + updatedKVs = []kvSeen{ + {[]byte{1, 2, 4}, []byte("zaq"), false}, + {[]byte{1, 2, 6}, []byte("zaq"), false}, + {[]byte{1, 3, 2}, []byte("wop"), false}, + {[]byte{1, 3, 4}, []byte("zaq"), false}, + } + ps = NewMemoryStore() + ts = NewMemCachedStore(ps) + ) + for _, v := range lowerKVs { + require.NoError(t, ps.Put(v.key, v.val)) + } + for _, v := range deletedKVs { + require.NoError(t, ps.Put(v.key, v.val)) + require.NoError(t, ts.Delete(v.key)) + } + for _, v := range updatedKVs { + require.NoError(t, ps.Put(v.key, []byte("stub"))) + require.NoError(t, ts.Put(v.key, v.val)) + } + var foundKVs []kvSeen + ts.Seek(goodPrefix, func(k, v []byte) { + foundKVs = append(foundKVs, kvSeen{key: slice.Copy(k), val: slice.Copy(v)}) + }) + assert.Equal(t, len(foundKVs), len(lowerKVs)+len(updatedKVs)) + expected := append(lowerKVs, updatedKVs...) + sort.Slice(expected, func(i, j int) bool { + return bytes.Compare(expected[i].key, expected[j].key) < 0 + }) + require.Equal(t, expected, foundKVs) +} diff --git a/pkg/core/storage/memory_store.go b/pkg/core/storage/memory_store.go index 45c4ae9fd..21ca47645 100644 --- a/pkg/core/storage/memory_store.go +++ b/pkg/core/storage/memory_store.go @@ -1,6 +1,8 @@ package storage import ( + "bytes" + "sort" "strings" "sync" @@ -128,11 +130,21 @@ func (s *MemoryStore) SeekAll(key []byte, f func(k, v []byte)) { // seek is an internal unlocked implementation of Seek. func (s *MemoryStore) seek(key []byte, f func(k, v []byte)) { sk := string(key) + var memList []KeyValue for k, v := range s.mem { if strings.HasPrefix(k, sk) { - f([]byte(k), v) + memList = append(memList, KeyValue{ + Key: []byte(k), + Value: v, + }) } } + sort.Slice(memList, func(i, j int) bool { + return bytes.Compare(memList[i].Key, memList[j].Key) < 0 + }) + for _, kv := range memList { + f(kv.Key, kv.Value) + } } // Batch implements the Batch interface and returns a compatible Batch. diff --git a/pkg/core/storage/memory_store_test.go b/pkg/core/storage/memory_store_test.go index 259dbae68..6bb7d7526 100644 --- a/pkg/core/storage/memory_store_test.go +++ b/pkg/core/storage/memory_store_test.go @@ -1,9 +1,35 @@ package storage import ( + "fmt" "testing" + + "github.com/nspcc-dev/neo-go/internal/random" + "github.com/stretchr/testify/require" ) -func newMemoryStoreForTesting(t *testing.T) Store { +func newMemoryStoreForTesting(t testing.TB) Store { return NewMemoryStore() } + +func BenchmarkMemorySeek(t *testing.B) { + for count := 10; count <= 10000; count *= 10 { + t.Run(fmt.Sprintf("%dElements", count), func(t *testing.B) { + ms := NewMemoryStore() + var ( + searchPrefix = []byte{1} + badPrefix = []byte{2} + ) + for i := 0; i < count; i++ { + require.NoError(t, ms.Put(append(searchPrefix, random.Bytes(10)...), random.Bytes(10))) + require.NoError(t, ms.Put(append(badPrefix, random.Bytes(10)...), random.Bytes(10))) + } + + t.ReportAllocs() + t.ResetTimer() + for n := 0; n < t.N; n++ { + ms.Seek(searchPrefix, func(k, v []byte) {}) + } + }) + } +} diff --git a/pkg/core/storage/redis_store_test.go b/pkg/core/storage/redis_store_test.go index d5613303b..f18aca684 100644 --- a/pkg/core/storage/redis_store_test.go +++ b/pkg/core/storage/redis_store_test.go @@ -12,7 +12,7 @@ type mockedRedisStore struct { mini *miniredis.Miniredis } -func prepareRedisMock(t *testing.T) (*miniredis.Miniredis, *RedisStore) { +func prepareRedisMock(t testing.TB) (*miniredis.Miniredis, *RedisStore) { miniRedis, err := miniredis.Run() require.Nil(t, err, "MiniRedis mock creation error") @@ -37,7 +37,7 @@ func (mrs *mockedRedisStore) Close() error { return err } -func newRedisStoreForTesting(t *testing.T) Store { +func newRedisStoreForTesting(t testing.TB) Store { mock, rs := prepareRedisMock(t) mrs := &mockedRedisStore{RedisStore: *rs, mini: mock} return mrs diff --git a/pkg/core/storage/store.go b/pkg/core/storage/store.go index bd62f6001..f87f94e87 100644 --- a/pkg/core/storage/store.go +++ b/pkg/core/storage/store.go @@ -55,7 +55,8 @@ type ( // PutChangeSet allows to push prepared changeset to the Store. PutChangeSet(puts map[string][]byte, dels map[string]bool) error // Seek can guarantee that provided key (k) and value (v) are the only valid until the next call to f. - // Key and value slices should not be modified. + // Key and value slices should not be modified. Seek can guarantee that key-value items are sorted by + // key in ascending way. Seek(k []byte, f func(k, v []byte)) Close() error } diff --git a/pkg/core/storage/storeandbatch_test.go b/pkg/core/storage/storeandbatch_test.go index 022d83adb..70a3787ea 100644 --- a/pkg/core/storage/storeandbatch_test.go +++ b/pkg/core/storage/storeandbatch_test.go @@ -19,7 +19,7 @@ type kvSeen struct { type dbSetup struct { name string - create func(*testing.T) Store + create func(testing.TB) Store } type dbTestFunction func(*testing.T, Store) diff --git a/pkg/rpc/response/result/invoke.go b/pkg/rpc/response/result/invoke.go index 83685a7ac..ebd7b699c 100644 --- a/pkg/rpc/response/result/invoke.go +++ b/pkg/rpc/response/result/invoke.go @@ -20,10 +20,11 @@ type Invoke struct { FaultException string Transaction *transaction.Transaction maxIteratorResultItems int + finalize func() } // NewInvoke returns new Invoke structure with the given fields set. -func NewInvoke(vm *vm.VM, script []byte, faultException string, maxIteratorResultItems int) *Invoke { +func NewInvoke(vm *vm.VM, finalize func(), script []byte, faultException string, maxIteratorResultItems int) *Invoke { return &Invoke{ State: vm.State().String(), GasConsumed: vm.GasConsumed(), @@ -31,6 +32,7 @@ func NewInvoke(vm *vm.VM, script []byte, faultException string, maxIteratorResul Stack: vm.Estack().ToArray(), FaultException: faultException, maxIteratorResultItems: maxIteratorResultItems, + finalize: finalize, } } @@ -55,8 +57,17 @@ type Iterator struct { Truncated bool } +// Finalize releases resources occupied by Iterators created at the script invocation. +// This method will be called automatically on Invoke marshalling. +func (r *Invoke) Finalize() { + if r.finalize != nil { + r.finalize() + } +} + // MarshalJSON implements json.Marshaler. func (r Invoke) MarshalJSON() ([]byte, error) { + defer r.Finalize() var st json.RawMessage arr := make([]json.RawMessage, len(r.Stack)) for i := range arr { diff --git a/pkg/rpc/server/client_test.go b/pkg/rpc/server/client_test.go index 4bff819b0..cd3b3a291 100644 --- a/pkg/rpc/server/client_test.go +++ b/pkg/rpc/server/client_test.go @@ -714,7 +714,7 @@ func TestCreateNEP17TransferTx(t *testing.T) { require.NoError(t, err) require.NoError(t, acc.SignTx(testchain.Network(), tx)) require.NoError(t, chain.VerifyTx(tx)) - v := chain.GetTestVM(trigger.Application, tx, nil) + v, _ := chain.GetTestVM(trigger.Application, tx, nil) v.LoadScriptWithFlags(tx.Script, callflag.All) require.NoError(t, v.Run()) } diff --git a/pkg/rpc/server/server.go b/pkg/rpc/server/server.go index fe7c06487..051dda712 100644 --- a/pkg/rpc/server/server.go +++ b/pkg/rpc/server/server.go @@ -624,6 +624,7 @@ func (s *Server) calculateNetworkFee(reqParams request.Params) (interface{}, *re if respErr != nil { return 0, respErr } + res.Finalize() if res.State != "HALT" { cause := fmt.Errorf("invalid VM state %s due to an error: %s", res.State, res.FaultException) return 0, response.NewRPCError(verificationErr, cause.Error(), cause) @@ -742,7 +743,8 @@ func (s *Server) getNEP17Balance(h util.Uint160, acc util.Uint160, bw *io.BufBin } script := bw.Bytes() tx := &transaction.Transaction{Script: script} - v := s.chain.GetTestVM(trigger.Application, tx, nil) + v, finalize := s.chain.GetTestVM(trigger.Application, tx, nil) + defer finalize() v.GasLimit = core.HeaderVerificationGasLimit v.LoadScriptWithFlags(script, callflag.All) err := v.Run() @@ -1490,7 +1492,6 @@ func (s *Server) invokeContractVerify(reqParams request.Params) (interface{}, *r tx.Signers = []transaction.Signer{{Account: scriptHash}} tx.Scripts = []transaction.Witness{{InvocationScript: invocationScript, VerificationScript: []byte{}}} } - return s.runScriptInVM(trigger.Verification, invocationScript, scriptHash, tx) } @@ -1511,7 +1512,7 @@ func (s *Server) runScriptInVM(t trigger.Type, script []byte, contractScriptHash } b.Timestamp = hdr.Timestamp + uint64(s.chain.GetConfig().SecondsPerBlock*int(time.Second/time.Millisecond)) - vm := s.chain.GetTestVM(t, tx, b) + vm, finalize := s.chain.GetTestVM(t, tx, b) vm.GasLimit = int64(s.config.MaxGasInvoke) if t == trigger.Verification { // We need this special case because witnesses verification is not the simple System.Contract.Call, @@ -1539,7 +1540,7 @@ func (s *Server) runScriptInVM(t trigger.Type, script []byte, contractScriptHash if err != nil { faultException = err.Error() } - return result.NewInvoke(vm, script, faultException, s.config.MaxIteratorResultItems), nil + return result.NewInvoke(vm, finalize, script, faultException, s.config.MaxIteratorResultItems), nil } // submitBlock broadcasts a raw block over the NEO network. diff --git a/pkg/services/oracle/response.go b/pkg/services/oracle/response.go index 6b5f9d541..38a69550e 100644 --- a/pkg/services/oracle/response.go +++ b/pkg/services/oracle/response.go @@ -134,16 +134,17 @@ func (o *Oracle) CreateResponseTx(gasForResponse int64, vub uint32, resp *transa } func (o *Oracle) testVerify(tx *transaction.Transaction) (int64, bool) { - v := o.Chain.GetTestVM(trigger.Verification, tx, nil) + v, finalize := o.Chain.GetTestVM(trigger.Verification, tx, nil) v.GasLimit = o.Chain.GetPolicer().GetMaxVerificationGAS() v.LoadScriptWithHash(o.oracleScript, o.oracleHash, callflag.ReadOnly) v.Jump(v.Context(), o.verifyOffset) - ok := isVerifyOk(v) + ok := isVerifyOk(v, finalize) return v.GasConsumed(), ok } -func isVerifyOk(v *vm.VM) bool { +func isVerifyOk(v *vm.VM, finalize func()) bool { + defer finalize() if err := v.Run(); err != nil { return false }