Merge pull request #2193 from nspcc-dev/optimize-find

core: optimise (*MemCachedStorage).Seek
This commit is contained in:
Roman Khimov 2021-10-21 21:20:33 +03:00 committed by GitHub
commit d551439654
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
29 changed files with 587 additions and 188 deletions

View file

@ -327,12 +327,12 @@ func (chain *FakeChain) GetStorageItem(id int32, key []byte) state.StorageItem {
} }
// GetTestVM implements Blockchainer interface. // GetTestVM implements Blockchainer interface.
func (chain *FakeChain) GetTestVM(t trigger.Type, tx *transaction.Transaction, b *block.Block) *vm.VM { func (chain *FakeChain) GetTestVM(t trigger.Type, tx *transaction.Transaction, b *block.Block) (*vm.VM, func()) {
panic("TODO") panic("TODO")
} }
// GetStorageItems implements Blockchainer interface. // GetStorageItems implements Blockchainer interface.
func (chain *FakeChain) GetStorageItems(id int32) (map[string]state.StorageItem, error) { func (chain *FakeChain) GetStorageItems(id int32) ([]state.StorageItemWithKey, error) {
panic("TODO") panic("TODO")
} }

View file

@ -490,9 +490,8 @@ func (bc *Blockchain) jumpToStateInternal(p uint32, stage stateJumpStage) error
// Firstly, remove all old genesis-related items. // Firstly, remove all old genesis-related items.
b := bc.dao.Store.Batch() b := bc.dao.Store.Batch()
bc.dao.Store.Seek([]byte{byte(storage.STStorage)}, func(k, _ []byte) { bc.dao.Store.Seek([]byte{byte(storage.STStorage)}, func(k, _ []byte) {
// Must copy here, #1468. // #1468, but don't need to copy here, because it is done by Store.
key := slice.Copy(k) b.Delete(k)
b.Delete(key)
}) })
b.Put(jumpStageKey, []byte{byte(oldStorageItemsRemoved)}) b.Put(jumpStageKey, []byte{byte(oldStorageItemsRemoved)})
err := bc.dao.Store.PutBatch(b) err := bc.dao.Store.PutBatch(b)
@ -509,14 +508,12 @@ func (bc *Blockchain) jumpToStateInternal(p uint32, stage stateJumpStage) error
if count >= maxStorageBatchSize { if count >= maxStorageBatchSize {
return return
} }
// Must copy here, #1468. // #1468, but don't need to copy here, because it is done by Store.
oldKey := slice.Copy(k) b.Delete(k)
b.Delete(oldKey)
key := make([]byte, len(k)) key := make([]byte, len(k))
key[0] = byte(storage.STStorage) key[0] = byte(storage.STStorage)
copy(key[1:], k[1:]) copy(key[1:], k[1:])
value := slice.Copy(v) b.Put(key, slice.Copy(v))
b.Put(key, value)
count += 2 count += 2
}) })
if count > 0 { if count > 0 {
@ -1039,7 +1036,7 @@ func (bc *Blockchain) storeBlock(block *block.Block, txpool *mempool.Pool) error
v.LoadToken = contract.LoadToken(systemInterop) v.LoadToken = contract.LoadToken(systemInterop)
v.GasLimit = tx.SystemFee v.GasLimit = tx.SystemFee
err := v.Run() err := systemInterop.Exec()
var faultException string var faultException string
if !v.HasFailed() { if !v.HasFailed() {
_, err := systemInterop.DAO.Persist() _, err := systemInterop.DAO.Persist()
@ -1226,7 +1223,7 @@ func (bc *Blockchain) runPersist(script []byte, block *block.Block, cache dao.DA
v := systemInterop.SpawnVM() v := systemInterop.SpawnVM()
v.LoadScriptWithFlags(script, callflag.All) v.LoadScriptWithFlags(script, callflag.All)
v.SetPriceGetter(systemInterop.GetPrice) v.SetPriceGetter(systemInterop.GetPrice)
if err := v.Run(); err != nil { if err := systemInterop.Exec(); err != nil {
return nil, fmt.Errorf("VM has failed: %w", err) return nil, fmt.Errorf("VM has failed: %w", err)
} else if _, err := systemInterop.DAO.Persist(); err != nil { } else if _, err := systemInterop.DAO.Persist(); err != nil {
return nil, fmt.Errorf("can't save changes: %w", err) return nil, fmt.Errorf("can't save changes: %w", err)
@ -1494,7 +1491,7 @@ func (bc *Blockchain) GetStorageItem(id int32, key []byte) state.StorageItem {
} }
// GetStorageItems returns all storage items for a given contract id. // GetStorageItems returns all storage items for a given contract id.
func (bc *Blockchain) GetStorageItems(id int32) (map[string]state.StorageItem, error) { func (bc *Blockchain) GetStorageItems(id int32) ([]state.StorageItemWithKey, error) {
return bc.dao.GetStorageItems(id) return bc.dao.GetStorageItems(id)
} }
@ -2055,14 +2052,14 @@ func (bc *Blockchain) GetEnrollments() ([]state.Validator, error) {
return bc.contracts.NEO.GetCandidates(bc.dao) return bc.contracts.NEO.GetCandidates(bc.dao)
} }
// GetTestVM returns a VM and a Store setup for a test run of some sort of code. // GetTestVM returns a VM setup for a test run of some sort of code and finalizer function.
func (bc *Blockchain) GetTestVM(t trigger.Type, tx *transaction.Transaction, b *block.Block) *vm.VM { func (bc *Blockchain) GetTestVM(t trigger.Type, tx *transaction.Transaction, b *block.Block) (*vm.VM, func()) {
d := bc.dao.GetWrapped().(*dao.Simple) d := bc.dao.GetWrapped().(*dao.Simple)
systemInterop := bc.newInteropContext(t, d, b, tx) systemInterop := bc.newInteropContext(t, d, b, tx)
vm := systemInterop.SpawnVM() vm := systemInterop.SpawnVM()
vm.SetPriceGetter(systemInterop.GetPrice) vm.SetPriceGetter(systemInterop.GetPrice)
vm.LoadToken = contract.LoadToken(systemInterop) vm.LoadToken = contract.LoadToken(systemInterop)
return vm return vm, systemInterop.Finalize
} }
// Various witness verification errors. // Various witness verification errors.
@ -2141,7 +2138,7 @@ func (bc *Blockchain) verifyHashAgainstScript(hash util.Uint160, witness *transa
if err := bc.InitVerificationVM(vm, interopCtx.GetContract, hash, witness); err != nil { if err := bc.InitVerificationVM(vm, interopCtx.GetContract, hash, witness); err != nil {
return 0, err return 0, err
} }
err := vm.Run() err := interopCtx.Exec()
if vm.HasFailed() { if vm.HasFailed() {
return 0, fmt.Errorf("%w: vm execution has failed: %v", ErrVerificationFailed, err) return 0, fmt.Errorf("%w: vm execution has failed: %v", ErrVerificationFailed, err)
} }

View file

@ -58,8 +58,8 @@ type Blockchainer interface {
GetStateModule() StateRoot GetStateModule() StateRoot
GetStateSyncModule() StateSync GetStateSyncModule() StateSync
GetStorageItem(id int32, key []byte) state.StorageItem GetStorageItem(id int32, key []byte) state.StorageItem
GetStorageItems(id int32) (map[string]state.StorageItem, error) GetStorageItems(id int32) ([]state.StorageItemWithKey, error)
GetTestVM(t trigger.Type, tx *transaction.Transaction, b *block.Block) *vm.VM GetTestVM(t trigger.Type, tx *transaction.Transaction, b *block.Block) (*vm.VM, func())
GetTransaction(util.Uint256) (*transaction.Transaction, uint32, error) GetTransaction(util.Uint256) (*transaction.Transaction, uint32, error)
SetOracle(service services.Oracle) SetOracle(service services.Oracle)
mempool.Feer // fee interface mempool.Feer // fee interface

View file

@ -2,6 +2,7 @@ package dao
import ( import (
"bytes" "bytes"
"context"
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt" "fmt"
@ -16,7 +17,6 @@ import (
"github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/io"
"github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger" "github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger"
"github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/util"
"github.com/nspcc-dev/neo-go/pkg/util/slice"
) )
// HasTransaction errors. // HasTransaction errors.
@ -48,8 +48,8 @@ type DAO interface {
GetStateSyncPoint() (uint32, error) GetStateSyncPoint() (uint32, error)
GetStateSyncCurrentBlockHeight() (uint32, error) GetStateSyncCurrentBlockHeight() (uint32, error)
GetStorageItem(id int32, key []byte) state.StorageItem GetStorageItem(id int32, key []byte) state.StorageItem
GetStorageItems(id int32) (map[string]state.StorageItem, error) GetStorageItems(id int32) ([]state.StorageItemWithKey, error)
GetStorageItemsWithPrefix(id int32, prefix []byte) (map[string]state.StorageItem, error) GetStorageItemsWithPrefix(id int32, prefix []byte) ([]state.StorageItemWithKey, error)
GetTransaction(hash util.Uint256) (*transaction.Transaction, uint32, error) GetTransaction(hash util.Uint256) (*transaction.Transaction, uint32, error)
GetVersion() (string, error) GetVersion() (string, error)
GetWrapped() DAO GetWrapped() DAO
@ -65,6 +65,7 @@ type DAO interface {
PutStorageItem(id int32, key []byte, si state.StorageItem) error PutStorageItem(id int32, key []byte, si state.StorageItem) error
PutVersion(v string) error PutVersion(v string) error
Seek(id int32, prefix []byte, f func(k, v []byte)) Seek(id int32, prefix []byte, f func(k, v []byte))
SeekAsync(ctx context.Context, id int32, prefix []byte) chan storage.KeyValue
StoreAsBlock(block *block.Block, buf *io.BufBinWriter) error StoreAsBlock(block *block.Block, buf *io.BufBinWriter) error
StoreAsCurrentBlock(block *block.Block, buf *io.BufBinWriter) error StoreAsCurrentBlock(block *block.Block, buf *io.BufBinWriter) error
StoreAsTransaction(tx *transaction.Transaction, index uint32, buf *io.BufBinWriter) error StoreAsTransaction(tx *transaction.Transaction, index uint32, buf *io.BufBinWriter) error
@ -313,28 +314,29 @@ func (dao *Simple) DeleteStorageItem(id int32, key []byte) error {
} }
// GetStorageItems returns all storage items for a given id. // GetStorageItems returns all storage items for a given id.
func (dao *Simple) GetStorageItems(id int32) (map[string]state.StorageItem, error) { func (dao *Simple) GetStorageItems(id int32) ([]state.StorageItemWithKey, error) {
return dao.GetStorageItemsWithPrefix(id, nil) return dao.GetStorageItemsWithPrefix(id, nil)
} }
// GetStorageItemsWithPrefix returns all storage items with given id for a // GetStorageItemsWithPrefix returns all storage items with given id for a
// given scripthash. // given scripthash.
func (dao *Simple) GetStorageItemsWithPrefix(id int32, prefix []byte) (map[string]state.StorageItem, error) { func (dao *Simple) GetStorageItemsWithPrefix(id int32, prefix []byte) ([]state.StorageItemWithKey, error) {
var siMap = make(map[string]state.StorageItem) var siArr []state.StorageItemWithKey
saveToMap := func(k, v []byte) { saveToArr := func(k, v []byte) {
// Cut prefix and hash. // Cut prefix and hash.
// Must copy here, #1468. // #1468, but don't need to copy here, because it is done by Store.
key := slice.Copy(k) siArr = append(siArr, state.StorageItemWithKey{
val := slice.Copy(v) Key: k,
siMap[string(key)] = state.StorageItem(val) Item: state.StorageItem(v),
})
} }
dao.Seek(id, prefix, saveToMap) dao.Seek(id, prefix, saveToArr)
return siMap, nil return siArr, nil
} }
// Seek executes f for all items with a given prefix. // Seek executes f for all items with a given prefix.
// If key is to be used outside of f, they must be copied. // If key is to be used outside of f, they may not be copied.
func (dao *Simple) Seek(id int32, prefix []byte, f func(k, v []byte)) { func (dao *Simple) Seek(id int32, prefix []byte, f func(k, v []byte)) {
lookupKey := makeStorageItemKey(id, nil) lookupKey := makeStorageItemKey(id, nil)
if prefix != nil { if prefix != nil {
@ -345,6 +347,16 @@ func (dao *Simple) Seek(id int32, prefix []byte, f func(k, v []byte)) {
}) })
} }
// SeekAsync sends all storage items matching given prefix to a channel and returns
// the channel. Resulting keys and values may not be copied.
func (dao *Simple) SeekAsync(ctx context.Context, id int32, prefix []byte) chan storage.KeyValue {
lookupKey := makeStorageItemKey(id, nil)
if prefix != nil {
lookupKey = append(lookupKey, prefix...)
}
return dao.Store.SeekAsync(ctx, lookupKey, true)
}
// makeStorageItemKey returns a key used to store StorageItem in the DB. // makeStorageItemKey returns a key used to store StorageItem in the DB.
func makeStorageItemKey(id int32, key []byte) []byte { func makeStorageItemKey(id int32, key []byte) []byte {
// 1 for prefix + 4 for Uint32 + len(key) for key // 1 for prefix + 4 for Uint32 + len(key) for key

View file

@ -1,6 +1,7 @@
package interop package interop
import ( import (
"context"
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt" "fmt"
@ -47,6 +48,7 @@ type Context struct {
Log *zap.Logger Log *zap.Logger
VM *vm.VM VM *vm.VM
Functions []Function Functions []Function
cancelFuncs []context.CancelFunc
getContract func(dao.DAO, util.Uint160) (*state.Contract, error) getContract func(dao.DAO, util.Uint160) (*state.Contract, error)
baseExecFee int64 baseExecFee int64
} }
@ -285,3 +287,25 @@ func (ic *Context) SpawnVM() *vm.VM {
ic.VM = v ic.VM = v
return v return v
} }
// RegisterCancelFunc adds given function to the list of functions to be called after VM
// finishes script execution.
func (ic *Context) RegisterCancelFunc(f context.CancelFunc) {
if f != nil {
ic.cancelFuncs = append(ic.cancelFuncs, f)
}
}
// Finalize calls all registered cancel functions to release the occupied resources.
func (ic *Context) Finalize() {
for _, f := range ic.cancelFuncs {
f()
}
ic.cancelFuncs = nil
}
// Exec executes loaded VM script and calls registered finalizers to release the occupied resources.
func (ic *Context) Exec() error {
defer ic.Finalize()
return ic.VM.Run()
}

View file

@ -1,6 +1,10 @@
package storage package storage
import "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" import (
"github.com/nspcc-dev/neo-go/pkg/core/storage"
"github.com/nspcc-dev/neo-go/pkg/util/slice"
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
)
// Storage iterator options. // Storage iterator options.
const ( const (
@ -18,44 +22,45 @@ const (
// Iterator is an iterator state representation. // Iterator is an iterator state representation.
type Iterator struct { type Iterator struct {
m []stackitem.MapElement seekCh chan storage.KeyValue
curr storage.KeyValue
next bool
opts int64 opts int64
index int prefix []byte
prefixSize int
} }
// NewIterator creates a new Iterator with given options for a given map. // NewIterator creates a new Iterator with given options for a given channel of store.Seek results.
func NewIterator(m *stackitem.Map, prefix int, opts int64) *Iterator { func NewIterator(seekCh chan storage.KeyValue, prefix []byte, opts int64) *Iterator {
return &Iterator{ return &Iterator{
m: m.Value().([]stackitem.MapElement), seekCh: seekCh,
opts: opts, opts: opts,
index: -1, prefix: slice.Copy(prefix),
prefixSize: prefix,
} }
} }
// Next advances the iterator and returns true if Value can be called at the // Next advances the iterator and returns true if Value can be called at the
// current position. // current position.
func (s *Iterator) Next() bool { func (s *Iterator) Next() bool {
if s.index < len(s.m) { s.curr, s.next = <-s.seekCh
s.index++ return s.next
}
return s.index < len(s.m)
} }
// Value returns current iterators value (exact type depends on options this // Value returns current iterators value (exact type depends on options this
// iterator was created with). // iterator was created with).
func (s *Iterator) Value() stackitem.Item { func (s *Iterator) Value() stackitem.Item {
key := s.m[s.index].Key.Value().([]byte) if !s.next {
if s.opts&FindRemovePrefix != 0 { panic("iterator index out of range")
key = key[s.prefixSize:] }
key := s.curr.Key
if s.opts&FindRemovePrefix == 0 {
key = append(s.prefix, key...)
} }
if s.opts&FindKeysOnly != 0 { if s.opts&FindKeysOnly != 0 {
return stackitem.NewByteArray(key) return stackitem.NewByteArray(key)
} }
value := s.m[s.index].Value value := stackitem.Item(stackitem.NewByteArray(s.curr.Value))
if s.opts&FindDeserialize != 0 { if s.opts&FindDeserialize != 0 {
bs := s.m[s.index].Value.Value().([]byte) bs := s.curr.Value
var err error var err error
value, err = stackitem.Deserialize(bs) value, err = stackitem.Deserialize(bs)
if err != nil { if err != nil {

View file

@ -1,13 +1,12 @@
package core package core
import ( import (
"bytes" "context"
"crypto/elliptic" "crypto/elliptic"
"errors" "errors"
"fmt" "fmt"
"math" "math"
"math/big" "math/big"
"sort"
"github.com/nspcc-dev/neo-go/pkg/core/block" "github.com/nspcc-dev/neo-go/pkg/core/block"
"github.com/nspcc-dev/neo-go/pkg/core/interop" "github.com/nspcc-dev/neo-go/pkg/core/interop"
@ -188,30 +187,13 @@ func storageFind(ic *interop.Context) error {
if opts&istorage.FindDeserialize == 0 && (opts&istorage.FindPick0 != 0 || opts&istorage.FindPick1 != 0) { if opts&istorage.FindDeserialize == 0 && (opts&istorage.FindPick0 != 0 || opts&istorage.FindPick1 != 0) {
return fmt.Errorf("%w: PickN is specified without Deserialize", errFindInvalidOptions) return fmt.Errorf("%w: PickN is specified without Deserialize", errFindInvalidOptions)
} }
siMap, err := ic.DAO.GetStorageItemsWithPrefix(stc.ID, prefix) // Items in seekres should be sorted by key, but GetStorageItemsWithPrefix returns
if err != nil { // sorted items, so no need to sort them one more time.
return err ctx, cancel := context.WithCancel(context.Background())
} seekres := ic.DAO.SeekAsync(ctx, stc.ID, prefix)
item := istorage.NewIterator(seekres, prefix, opts)
arr := make([]stackitem.MapElement, 0, len(siMap))
for k, v := range siMap {
keycopy := make([]byte, len(k)+len(prefix))
copy(keycopy, prefix)
copy(keycopy[len(prefix):], k)
arr = append(arr, stackitem.MapElement{
Key: stackitem.NewByteArray(keycopy),
Value: stackitem.NewByteArray(v),
})
}
sort.Slice(arr, func(i, j int) bool {
k1 := arr[i].Key.Value().([]byte)
k2 := arr[j].Key.Value().([]byte)
return bytes.Compare(k1, k2) == -1
})
filteredMap := stackitem.NewMapWithValue(arr)
item := istorage.NewIterator(filteredMap, len(prefix), opts)
ic.VM.Estack().PushItem(stackitem.NewInterop(item)) ic.VM.Estack().PushItem(stackitem.NewInterop(item))
ic.RegisterCancelFunc(cancel)
return nil return nil
} }

View file

@ -2,6 +2,7 @@ package core
import ( import (
"errors" "errors"
"fmt"
"math" "math"
"math/big" "math/big"
"testing" "testing"
@ -336,11 +337,11 @@ func TestStorageDelete(t *testing.T) {
} }
func BenchmarkStorageFind(b *testing.B) { func BenchmarkStorageFind(b *testing.B) {
for count := 10; count <= 10000; count *= 10 {
b.Run(fmt.Sprintf("%dElements", count), func(b *testing.B) {
v, contractState, context, chain := createVMAndContractState(b) v, contractState, context, chain := createVMAndContractState(b)
require.NoError(b, chain.contracts.Management.PutContractState(chain.dao, contractState)) require.NoError(b, chain.contracts.Management.PutContractState(chain.dao, contractState))
const count = 100
items := make(map[string]state.StorageItem) items := make(map[string]state.StorageItem)
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
items["abc"+random.String(10)] = random.Bytes(10) items["abc"+random.String(10)] = random.Bytes(10)
@ -349,6 +350,9 @@ func BenchmarkStorageFind(b *testing.B) {
require.NoError(b, context.DAO.PutStorageItem(contractState.ID, []byte(k), v)) require.NoError(b, context.DAO.PutStorageItem(contractState.ID, []byte(k), v))
require.NoError(b, context.DAO.PutStorageItem(contractState.ID+1, []byte(k), v)) require.NoError(b, context.DAO.PutStorageItem(contractState.ID+1, []byte(k), v))
} }
changes, err := context.DAO.Persist()
require.NoError(b, err)
require.NotEqual(b, 0, changes)
b.ResetTimer() b.ResetTimer()
b.ReportAllocs() b.ReportAllocs()
@ -362,6 +366,72 @@ func BenchmarkStorageFind(b *testing.B) {
if err != nil { if err != nil {
b.FailNow() b.FailNow()
} }
b.StopTimer()
context.Finalize()
}
})
}
}
func BenchmarkStorageFindIteratorNext(b *testing.B) {
for count := 10; count <= 10000; count *= 10 {
cases := map[string]int{
"Pick1": 1,
"PickHalf": count / 2,
"PickAll": count,
}
b.Run(fmt.Sprintf("%dElements", count), func(b *testing.B) {
for name, last := range cases {
b.Run(name, func(b *testing.B) {
v, contractState, context, chain := createVMAndContractState(b)
require.NoError(b, chain.contracts.Management.PutContractState(chain.dao, contractState))
items := make(map[string]state.StorageItem)
for i := 0; i < count; i++ {
items["abc"+random.String(10)] = random.Bytes(10)
}
for k, v := range items {
require.NoError(b, context.DAO.PutStorageItem(contractState.ID, []byte(k), v))
require.NoError(b, context.DAO.PutStorageItem(contractState.ID+1, []byte(k), v))
}
changes, err := context.DAO.Persist()
require.NoError(b, err)
require.NotEqual(b, 0, changes)
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
b.StopTimer()
v.Estack().PushVal(istorage.FindDefault)
v.Estack().PushVal("abc")
v.Estack().PushVal(stackitem.NewInterop(&StorageContext{ID: contractState.ID}))
b.StartTimer()
err := storageFind(context)
b.StopTimer()
if err != nil {
b.FailNow()
}
res := context.VM.Estack().Pop().Item()
for i := 0; i < last; i++ {
context.VM.Estack().PushVal(res)
b.StartTimer()
require.NoError(b, iterator.Next(context))
b.StopTimer()
require.True(b, context.VM.Estack().Pop().Bool())
}
context.VM.Estack().PushVal(res)
require.NoError(b, iterator.Next(context))
actual := context.VM.Estack().Pop().Bool()
if last == count {
require.False(b, actual)
} else {
require.True(b, actual)
}
context.Finalize()
}
})
}
})
} }
} }

View file

@ -257,17 +257,22 @@ func (s *Designate) GetDesignatedByRole(d dao.DAO, r noderoles.Role, index uint3
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
} }
var ns NodeList var (
var bestIndex uint32 ns NodeList
var resSi state.StorageItem bestIndex uint32
for k, si := range kvs { resSi state.StorageItem
if len(k) < 4 { )
// kvs are sorted by key (BE index bytes) in ascending way, so iterate backwards to get the latest designated.
for i := len(kvs) - 1; i >= 0; i-- {
kv := kvs[i]
if len(kv.Key) < 4 {
continue continue
} }
siInd := binary.BigEndian.Uint32([]byte(k)) siInd := binary.BigEndian.Uint32(kv.Key)
if (resSi == nil || siInd > bestIndex) && siInd <= index { if siInd <= index {
bestIndex = siInd bestIndex = siInd
resSi = si resSi = kv.Item
break
} }
} }
if resSi != nil { if resSi != nil {

View file

@ -391,12 +391,12 @@ func (m *Management) Destroy(d dao.DAO, hash util.Uint160) error {
if err != nil { if err != nil {
return err return err
} }
siMap, err := d.GetStorageItems(contract.ID) siArr, err := d.GetStorageItems(contract.ID)
if err != nil { if err != nil {
return err return err
} }
for k := range siMap { for _, kv := range siArr {
err := d.DeleteStorageItem(contract.ID, []byte(k)) err := d.DeleteStorageItem(contract.ID, []byte(kv.Key))
if err != nil { if err != nil {
return err return err
} }

View file

@ -471,24 +471,20 @@ func (n *NEO) getGASPerBlock(ic *interop.Context, _ []stackitem.Item) stackitem.
} }
func (n *NEO) getSortedGASRecordFromDAO(d dao.DAO) (gasRecord, error) { func (n *NEO) getSortedGASRecordFromDAO(d dao.DAO) (gasRecord, error) {
grMap, err := d.GetStorageItemsWithPrefix(n.ID, []byte{prefixGASPerBlock}) grArr, err := d.GetStorageItemsWithPrefix(n.ID, []byte{prefixGASPerBlock})
if err != nil { if err != nil {
return gasRecord{}, fmt.Errorf("failed to get gas records from storage: %w", err) return gasRecord{}, fmt.Errorf("failed to get gas records from storage: %w", err)
} }
var ( var gr = make(gasRecord, len(grArr))
i int for i, kv := range grArr {
gr = make(gasRecord, len(grMap)) indexBytes, gasValue := kv.Key, kv.Item
)
for indexBytes, gasValue := range grMap {
gr[i] = gasIndexPair{ gr[i] = gasIndexPair{
Index: binary.BigEndian.Uint32([]byte(indexBytes)), Index: binary.BigEndian.Uint32([]byte(indexBytes)),
GASPerBlock: *bigint.FromBytes(gasValue), GASPerBlock: *bigint.FromBytes(gasValue),
} }
i++
} }
sort.Slice(gr, func(i, j int) bool { // GAS records should be sorted by index, but GetStorageItemsWithPrefix returns
return gr[i].Index < gr[j].Index // values sorted by BE bytes of index, so we're OK with that.
})
return gr, nil return gr, nil
} }
@ -836,21 +832,21 @@ func (n *NEO) ModifyAccountVotes(acc *state.NEOBalance, d dao.DAO, value *big.In
} }
func (n *NEO) getCandidates(d dao.DAO, sortByKey bool) ([]keyWithVotes, error) { func (n *NEO) getCandidates(d dao.DAO, sortByKey bool) ([]keyWithVotes, error) {
siMap, err := d.GetStorageItemsWithPrefix(n.ID, []byte{prefixCandidate}) siArr, err := d.GetStorageItemsWithPrefix(n.ID, []byte{prefixCandidate})
if err != nil { if err != nil {
return nil, err return nil, err
} }
arr := make([]keyWithVotes, 0, len(siMap)) arr := make([]keyWithVotes, 0, len(siArr))
for key, si := range siMap { for _, kv := range siArr {
c := new(candidate).FromBytes(si) c := new(candidate).FromBytes(kv.Item)
if c.Registered { if c.Registered {
arr = append(arr, keyWithVotes{Key: key, Votes: &c.Votes}) arr = append(arr, keyWithVotes{Key: string(kv.Key), Votes: &c.Votes})
} }
} }
if sortByKey { if !sortByKey {
// Sort by serialized key bytes (that's the way keys are stored and retrieved from the storage by default). // sortByKey assumes to sort by serialized key bytes (that's the way keys
sort.Slice(arr, func(i, j int) bool { return strings.Compare(arr[i].Key, arr[j].Key) == -1 }) // are stored and retrieved from the storage by default). Otherwise, need
} else { // to sort using big.Int comparator.
sort.Slice(arr, func(i, j int) bool { sort.Slice(arr, func(i, j int) bool {
// The most-voted validators should end up in the front of the list. // The most-voted validators should end up in the front of the list.
cmp := arr[i].Votes.Cmp(arr[j].Votes) cmp := arr[i].Votes.Cmp(arr[j].Votes)

View file

@ -483,21 +483,21 @@ func (o *Oracle) getOriginalTxID(d dao.DAO, tx *transaction.Transaction) util.Ui
// getRequests returns all requests which have not been finished yet. // getRequests returns all requests which have not been finished yet.
func (o *Oracle) getRequests(d dao.DAO) (map[uint64]*state.OracleRequest, error) { func (o *Oracle) getRequests(d dao.DAO) (map[uint64]*state.OracleRequest, error) {
m, err := d.GetStorageItemsWithPrefix(o.ID, prefixRequest) arr, err := d.GetStorageItemsWithPrefix(o.ID, prefixRequest)
if err != nil { if err != nil {
return nil, err return nil, err
} }
reqs := make(map[uint64]*state.OracleRequest, len(m)) reqs := make(map[uint64]*state.OracleRequest, len(arr))
for k, si := range m { for _, kv := range arr {
if len(k) != 8 { if len(kv.Key) != 8 {
return nil, errors.New("invalid request ID") return nil, errors.New("invalid request ID")
} }
req := new(state.OracleRequest) req := new(state.OracleRequest)
err = stackitem.DeserializeConvertible(si, req) err = stackitem.DeserializeConvertible(kv.Item, req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
id := binary.BigEndian.Uint64([]byte(k)) id := binary.BigEndian.Uint64([]byte(kv.Key))
reqs[id] = req reqs[id] = req
} }
return reqs, nil return reqs, nil

View file

@ -162,20 +162,20 @@ func (p *Policy) PostPersist(ic *interop.Context) error {
p.storagePrice = uint32(getIntWithKey(p.ID, ic.DAO, storagePriceKey)) p.storagePrice = uint32(getIntWithKey(p.ID, ic.DAO, storagePriceKey))
p.blockedAccounts = make([]util.Uint160, 0) p.blockedAccounts = make([]util.Uint160, 0)
siMap, err := ic.DAO.GetStorageItemsWithPrefix(p.ID, []byte{blockedAccountPrefix}) siArr, err := ic.DAO.GetStorageItemsWithPrefix(p.ID, []byte{blockedAccountPrefix})
if err != nil { if err != nil {
return fmt.Errorf("failed to get blocked accounts from storage: %w", err) return fmt.Errorf("failed to get blocked accounts from storage: %w", err)
} }
for key := range siMap { for _, kv := range siArr {
hash, err := util.Uint160DecodeBytesBE([]byte(key)) hash, err := util.Uint160DecodeBytesBE([]byte(kv.Key))
if err != nil { if err != nil {
return fmt.Errorf("failed to decode blocked account hash: %w", err) return fmt.Errorf("failed to decode blocked account hash: %w", err)
} }
p.blockedAccounts = append(p.blockedAccounts, hash) p.blockedAccounts = append(p.blockedAccounts, hash)
} }
sort.Slice(p.blockedAccounts, func(i, j int) bool { // blockedAccounts should be sorted by account BE bytes, but GetStorageItemsWithPrefix
return p.blockedAccounts[i].Less(p.blockedAccounts[j]) // returns values sorted by key (which is account's BE bytes), so don't need to sort
}) // one more time.
p.isValid = true p.isValid = true
return nil return nil

View file

@ -2,3 +2,9 @@ package state
// StorageItem is the value to be stored with read-only flag. // StorageItem is the value to be stored with read-only flag.
type StorageItem []byte type StorageItem []byte
// StorageItemWithKey is a storage item with corresponding key.
type StorageItemWithKey struct {
Key []byte
Item StorageItem
}

View file

@ -13,7 +13,6 @@ import (
"github.com/nspcc-dev/neo-go/pkg/core/storage" "github.com/nspcc-dev/neo-go/pkg/core/storage"
"github.com/nspcc-dev/neo-go/pkg/crypto/keys" "github.com/nspcc-dev/neo-go/pkg/crypto/keys"
"github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/util"
"github.com/nspcc-dev/neo-go/pkg/util/slice"
"go.uber.org/atomic" "go.uber.org/atomic"
"go.uber.org/zap" "go.uber.org/zap"
) )
@ -147,9 +146,8 @@ func (s *Module) CleanStorage() error {
// //
b := s.Store.Batch() b := s.Store.Batch()
s.Store.Seek([]byte{byte(storage.DataMPT)}, func(k, _ []byte) { s.Store.Seek([]byte{byte(storage.DataMPT)}, func(k, _ []byte) {
// Must copy here, #1468. // #1468, but don't need to copy here, because it is done by Store.
key := slice.Copy(k) b.Delete(k)
b.Delete(key)
}) })
err = s.Store.PutBatch(b) err = s.Store.PutBatch(b)
if err != nil { if err != nil {

View file

@ -6,7 +6,7 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func newBadgerDBForTesting(t *testing.T) Store { func newBadgerDBForTesting(t testing.TB) Store {
bdbDir := t.TempDir() bdbDir := t.TempDir()
dbConfig := DBConfiguration{ dbConfig := DBConfiguration{
Type: "badgerdb", Type: "badgerdb",

View file

@ -7,7 +7,7 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func newBoltStoreForTesting(t *testing.T) Store { func newBoltStoreForTesting(t testing.TB) Store {
d := t.TempDir() d := t.TempDir()
testFileName := path.Join(d, "test_bolt_db") testFileName := path.Join(d, "test_bolt_db")
boltDBStore, err := NewBoltDBStore(BoltDBOptions{FilePath: testFileName}) boltDBStore, err := NewBoltDBStore(BoltDBOptions{FilePath: testFileName})

View file

@ -6,7 +6,7 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func newLevelDBForTesting(t *testing.T) Store { func newLevelDBForTesting(t testing.TB) Store {
ldbDir := t.TempDir() ldbDir := t.TempDir()
dbConfig := DBConfiguration{ dbConfig := DBConfiguration{
Type: "leveldb", Type: "leveldb",

View file

@ -1,6 +1,14 @@
package storage package storage
import "sync" import (
"bytes"
"context"
"sort"
"strings"
"sync"
"github.com/nspcc-dev/neo-go/pkg/util/slice"
)
// MemCachedStore is a wrapper around persistent store that caches all changes // MemCachedStore is a wrapper around persistent store that caches all changes
// being made for them to be later flushed in one batch. // being made for them to be later flushed in one batch.
@ -18,14 +26,20 @@ type (
KeyValue struct { KeyValue struct {
Key []byte Key []byte
Value []byte Value []byte
}
// KeyValueExists represents key-value pair with indicator whether the item
// exists in the persistent storage.
KeyValueExists struct {
KeyValue
Exists bool Exists bool
} }
// MemBatch represents a changeset to be persisted. // MemBatch represents a changeset to be persisted.
MemBatch struct { MemBatch struct {
Put []KeyValue Put []KeyValueExists
Deleted []KeyValue Deleted []KeyValueExists
} }
) )
@ -58,18 +72,18 @@ func (s *MemCachedStore) GetBatch() *MemBatch {
var b MemBatch var b MemBatch
b.Put = make([]KeyValue, 0, len(s.mem)) b.Put = make([]KeyValueExists, 0, len(s.mem))
for k, v := range s.mem { for k, v := range s.mem {
key := []byte(k) key := []byte(k)
_, err := s.ps.Get(key) _, err := s.ps.Get(key)
b.Put = append(b.Put, KeyValue{Key: key, Value: v, Exists: err == nil}) b.Put = append(b.Put, KeyValueExists{KeyValue: KeyValue{Key: key, Value: v}, Exists: err == nil})
} }
b.Deleted = make([]KeyValue, 0, len(s.del)) b.Deleted = make([]KeyValueExists, 0, len(s.del))
for k := range s.del { for k := range s.del {
key := []byte(k) key := []byte(k)
_, err := s.ps.Get(key) _, err := s.ps.Get(key)
b.Deleted = append(b.Deleted, KeyValue{Key: key, Exists: err == nil}) b.Deleted = append(b.Deleted, KeyValueExists{KeyValue: KeyValue{Key: key}, Exists: err == nil})
} }
return &b return &b
@ -77,21 +91,130 @@ func (s *MemCachedStore) GetBatch() *MemBatch {
// Seek implements the Store interface. // Seek implements the Store interface.
func (s *MemCachedStore) Seek(key []byte, f func(k, v []byte)) { func (s *MemCachedStore) Seek(key []byte, f func(k, v []byte)) {
s.mut.RLock() s.seek(context.Background(), key, false, f)
defer s.mut.RUnlock() }
s.MemoryStore.seek(key, f)
s.ps.Seek(key, func(k, v []byte) { // SeekAsync returns non-buffered channel with matching KeyValue pairs. Key and
elem := string(k) // value slices may not be copied and may be modified. SeekAsync can guarantee
// If it's in mem, we already called f() for it in MemoryStore.Seek(). // that key-value items are sorted by key in ascending way.
_, present := s.mem[elem] func (s *MemCachedStore) SeekAsync(ctx context.Context, key []byte, cutPrefix bool) chan KeyValue {
if !present { res := make(chan KeyValue)
// If it's in del, we shouldn't be calling f() anyway. go func() {
_, present = s.del[elem] s.seek(ctx, key, cutPrefix, func(k, v []byte) {
} res <- KeyValue{
if !present { Key: k,
f(k, v) Value: v,
} }
}) })
close(res)
}()
return res
}
func (s *MemCachedStore) seek(ctx context.Context, key []byte, cutPrefix bool, f func(k, v []byte)) {
// Create memory store `mem` and `del` snapshot not to hold the lock.
var memRes []KeyValueExists
sk := string(key)
s.mut.RLock()
for k, v := range s.MemoryStore.mem {
if strings.HasPrefix(k, sk) {
memRes = append(memRes, KeyValueExists{
KeyValue: KeyValue{
Key: []byte(k),
Value: v,
},
Exists: true,
})
}
}
for k := range s.MemoryStore.del {
if strings.HasPrefix(k, sk) {
memRes = append(memRes, KeyValueExists{
KeyValue: KeyValue{
Key: []byte(k),
},
})
}
}
ps := s.ps
s.mut.RUnlock()
// Sort memRes items for further comparison with ps items.
sort.Slice(memRes, func(i, j int) bool {
return bytes.Compare(memRes[i].Key, memRes[j].Key) < 0
})
var (
done bool
iMem int
kvMem KeyValueExists
haveMem bool
)
if iMem < len(memRes) {
kvMem = memRes[iMem]
haveMem = true
iMem++
}
// Merge results of seek operations in ascending order.
ps.Seek(key, func(k, v []byte) {
if done {
return
}
kvPs := KeyValue{
Key: slice.Copy(k),
Value: slice.Copy(v),
}
loop:
for {
select {
case <-ctx.Done():
done = true
break loop
default:
var isMem = haveMem && (bytes.Compare(kvMem.Key, kvPs.Key) < 0)
if isMem {
if kvMem.Exists {
if cutPrefix {
kvMem.Key = kvMem.Key[len(key):]
}
f(kvMem.Key, kvMem.Value)
}
if iMem < len(memRes) {
kvMem = memRes[iMem]
haveMem = true
iMem++
} else {
haveMem = false
}
} else {
if !bytes.Equal(kvMem.Key, kvPs.Key) {
if cutPrefix {
kvPs.Key = kvPs.Key[len(key):]
}
f(kvPs.Key, kvPs.Value)
}
break loop
}
}
}
})
if !done && haveMem {
loop:
for i := iMem - 1; i < len(memRes); i++ {
select {
case <-ctx.Done():
break loop
default:
kvMem = memRes[i]
if kvMem.Exists {
if cutPrefix {
kvMem.Key = kvMem.Key[len(key):]
}
f(kvMem.Key, kvMem.Value)
}
}
}
}
} }
// Persist flushes all the MemoryStore contents into the (supposedly) persistent // Persist flushes all the MemoryStore contents into the (supposedly) persistent

View file

@ -1,8 +1,13 @@
package storage package storage
import ( import (
"bytes"
"fmt"
"sort"
"testing" "testing"
"github.com/nspcc-dev/neo-go/internal/random"
"github.com/nspcc-dev/neo-go/pkg/util/slice"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -16,7 +21,7 @@ func testMemCachedStorePersist(t *testing.T, ps Store) {
assert.Equal(t, 0, c) assert.Equal(t, 0, c)
// persisting one key should result in one key in ps and nothing in ts // persisting one key should result in one key in ps and nothing in ts
assert.NoError(t, ts.Put([]byte("key"), []byte("value"))) assert.NoError(t, ts.Put([]byte("key"), []byte("value")))
checkBatch(t, ts, []KeyValue{{Key: []byte("key"), Value: []byte("value")}}, nil) checkBatch(t, ts, []KeyValueExists{{KeyValue: KeyValue{Key: []byte("key"), Value: []byte("value")}}}, nil)
c, err = ts.Persist() c, err = ts.Persist()
checkBatch(t, ts, nil, nil) checkBatch(t, ts, nil, nil)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
@ -35,9 +40,9 @@ func testMemCachedStorePersist(t *testing.T, ps Store) {
v, err = ps.Get([]byte("key2")) v, err = ps.Get([]byte("key2"))
assert.Equal(t, ErrKeyNotFound, err) assert.Equal(t, ErrKeyNotFound, err)
assert.Equal(t, []byte(nil), v) assert.Equal(t, []byte(nil), v)
checkBatch(t, ts, []KeyValue{ checkBatch(t, ts, []KeyValueExists{
{Key: []byte("key"), Value: []byte("newvalue"), Exists: true}, {KeyValue: KeyValue{Key: []byte("key"), Value: []byte("newvalue")}, Exists: true},
{Key: []byte("key2"), Value: []byte("value2")}, {KeyValue: KeyValue{Key: []byte("key2"), Value: []byte("value2")}},
}, nil) }, nil)
// two keys should be persisted (one overwritten and one new) and // two keys should be persisted (one overwritten and one new) and
// available in the ps // available in the ps
@ -65,7 +70,7 @@ func testMemCachedStorePersist(t *testing.T, ps Store) {
// test persisting deletions // test persisting deletions
err = ts.Delete([]byte("key")) err = ts.Delete([]byte("key"))
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
checkBatch(t, ts, nil, []KeyValue{{Key: []byte("key"), Exists: true}}) checkBatch(t, ts, nil, []KeyValueExists{{KeyValue: KeyValue{Key: []byte("key")}, Exists: true}})
c, err = ts.Persist() c, err = ts.Persist()
checkBatch(t, ts, nil, nil) checkBatch(t, ts, nil, nil)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
@ -78,7 +83,7 @@ func testMemCachedStorePersist(t *testing.T, ps Store) {
assert.Equal(t, []byte("value2"), v) assert.Equal(t, []byte("value2"), v)
} }
func checkBatch(t *testing.T, ts *MemCachedStore, put []KeyValue, del []KeyValue) { func checkBatch(t *testing.T, ts *MemCachedStore, put []KeyValueExists, del []KeyValueExists) {
b := ts.GetBatch() b := ts.GetBatch()
assert.Equal(t, len(put), len(b.Put), "wrong number of put elements in a batch") assert.Equal(t, len(put), len(b.Put), "wrong number of put elements in a batch")
assert.Equal(t, len(del), len(b.Deleted), "wrong number of deleted elements in a batch") assert.Equal(t, len(del), len(b.Deleted), "wrong number of deleted elements in a batch")
@ -174,7 +179,82 @@ func TestCachedSeek(t *testing.T) {
} }
} }
func newMemCachedStoreForTesting(t *testing.T) Store { func benchmarkCachedSeek(t *testing.B, ps Store, psElementsCount, tsElementsCount int) {
var (
searchPrefix = []byte{1}
badPrefix = []byte{2}
lowerPrefixGood = append(searchPrefix, 1)
lowerPrefixBad = append(badPrefix, 1)
deletedPrefixGood = append(searchPrefix, 2)
deletedPrefixBad = append(badPrefix, 2)
updatedPrefixGood = append(searchPrefix, 3)
updatedPrefixBad = append(badPrefix, 3)
ts = NewMemCachedStore(ps)
)
for i := 0; i < psElementsCount; i++ {
// lower KVs with matching prefix that should be found
require.NoError(t, ps.Put(append(lowerPrefixGood, random.Bytes(10)...), []byte("value")))
// lower KVs with non-matching prefix that shouldn't be found
require.NoError(t, ps.Put(append(lowerPrefixBad, random.Bytes(10)...), []byte("value")))
// deleted KVs with matching prefix that shouldn't be found
key := append(deletedPrefixGood, random.Bytes(10)...)
require.NoError(t, ps.Put(key, []byte("deleted")))
if i < tsElementsCount {
require.NoError(t, ts.Delete(key))
}
// deleted KVs with non-matching prefix that shouldn't be found
key = append(deletedPrefixBad, random.Bytes(10)...)
require.NoError(t, ps.Put(key, []byte("deleted")))
if i < tsElementsCount {
require.NoError(t, ts.Delete(key))
}
// updated KVs with matching prefix that should be found
key = append(updatedPrefixGood, random.Bytes(10)...)
require.NoError(t, ps.Put(key, []byte("stub")))
if i < tsElementsCount {
require.NoError(t, ts.Put(key, []byte("updated")))
}
// updated KVs with non-matching prefix that shouldn't be found
key = append(updatedPrefixBad, random.Bytes(10)...)
require.NoError(t, ps.Put(key, []byte("stub")))
if i < tsElementsCount {
require.NoError(t, ts.Put(key, []byte("updated")))
}
}
t.ReportAllocs()
t.ResetTimer()
for n := 0; n < t.N; n++ {
ts.Seek(searchPrefix, func(k, v []byte) {})
}
t.StopTimer()
}
func BenchmarkCachedSeek(t *testing.B) {
var stores = map[string]func(testing.TB) Store{
"MemPS": func(t testing.TB) Store {
return NewMemoryStore()
},
"BoltPS": newBoltStoreForTesting,
"LevelPS": newLevelDBForTesting,
}
for psName, newPS := range stores {
for psCount := 100; psCount <= 10000; psCount *= 10 {
for tsCount := 10; tsCount <= psCount; tsCount *= 10 {
t.Run(fmt.Sprintf("%s_%dTSItems_%dPSItems", psName, tsCount, psCount), func(t *testing.B) {
ps := newPS(t)
benchmarkCachedSeek(t, ps, psCount, tsCount)
ps.Close()
})
}
}
}
}
func newMemCachedStoreForTesting(t testing.TB) Store {
return NewMemCachedStore(NewMemoryStore()) return NewMemCachedStore(NewMemoryStore())
} }
@ -242,3 +322,52 @@ func TestMemCachedPersistFailing(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, b1, res) require.Equal(t, b1, res)
} }
func TestCachedSeekSorting(t *testing.T) {
var (
// Given this prefix...
goodPrefix = []byte{1}
// these pairs should be found...
lowerKVs = []kvSeen{
{[]byte{1, 2, 3}, []byte("bra"), false},
{[]byte{1, 2, 5}, []byte("bar"), false},
{[]byte{1, 3, 3}, []byte("bra"), false},
{[]byte{1, 3, 5}, []byte("bra"), false},
}
// and these should be not.
deletedKVs = []kvSeen{
{[]byte{1, 7, 3}, []byte("pow"), false},
{[]byte{1, 7, 4}, []byte("qaz"), false},
}
// and these should be not.
updatedKVs = []kvSeen{
{[]byte{1, 2, 4}, []byte("zaq"), false},
{[]byte{1, 2, 6}, []byte("zaq"), false},
{[]byte{1, 3, 2}, []byte("wop"), false},
{[]byte{1, 3, 4}, []byte("zaq"), false},
}
ps = NewMemoryStore()
ts = NewMemCachedStore(ps)
)
for _, v := range lowerKVs {
require.NoError(t, ps.Put(v.key, v.val))
}
for _, v := range deletedKVs {
require.NoError(t, ps.Put(v.key, v.val))
require.NoError(t, ts.Delete(v.key))
}
for _, v := range updatedKVs {
require.NoError(t, ps.Put(v.key, []byte("stub")))
require.NoError(t, ts.Put(v.key, v.val))
}
var foundKVs []kvSeen
ts.Seek(goodPrefix, func(k, v []byte) {
foundKVs = append(foundKVs, kvSeen{key: slice.Copy(k), val: slice.Copy(v)})
})
assert.Equal(t, len(foundKVs), len(lowerKVs)+len(updatedKVs))
expected := append(lowerKVs, updatedKVs...)
sort.Slice(expected, func(i, j int) bool {
return bytes.Compare(expected[i].key, expected[j].key) < 0
})
require.Equal(t, expected, foundKVs)
}

View file

@ -1,6 +1,8 @@
package storage package storage
import ( import (
"bytes"
"sort"
"strings" "strings"
"sync" "sync"
@ -128,11 +130,21 @@ func (s *MemoryStore) SeekAll(key []byte, f func(k, v []byte)) {
// seek is an internal unlocked implementation of Seek. // seek is an internal unlocked implementation of Seek.
func (s *MemoryStore) seek(key []byte, f func(k, v []byte)) { func (s *MemoryStore) seek(key []byte, f func(k, v []byte)) {
sk := string(key) sk := string(key)
var memList []KeyValue
for k, v := range s.mem { for k, v := range s.mem {
if strings.HasPrefix(k, sk) { if strings.HasPrefix(k, sk) {
f([]byte(k), v) memList = append(memList, KeyValue{
Key: []byte(k),
Value: v,
})
} }
} }
sort.Slice(memList, func(i, j int) bool {
return bytes.Compare(memList[i].Key, memList[j].Key) < 0
})
for _, kv := range memList {
f(kv.Key, kv.Value)
}
} }
// Batch implements the Batch interface and returns a compatible Batch. // Batch implements the Batch interface and returns a compatible Batch.

View file

@ -1,9 +1,35 @@
package storage package storage
import ( import (
"fmt"
"testing" "testing"
"github.com/nspcc-dev/neo-go/internal/random"
"github.com/stretchr/testify/require"
) )
func newMemoryStoreForTesting(t *testing.T) Store { func newMemoryStoreForTesting(t testing.TB) Store {
return NewMemoryStore() return NewMemoryStore()
} }
func BenchmarkMemorySeek(t *testing.B) {
for count := 10; count <= 10000; count *= 10 {
t.Run(fmt.Sprintf("%dElements", count), func(t *testing.B) {
ms := NewMemoryStore()
var (
searchPrefix = []byte{1}
badPrefix = []byte{2}
)
for i := 0; i < count; i++ {
require.NoError(t, ms.Put(append(searchPrefix, random.Bytes(10)...), random.Bytes(10)))
require.NoError(t, ms.Put(append(badPrefix, random.Bytes(10)...), random.Bytes(10)))
}
t.ReportAllocs()
t.ResetTimer()
for n := 0; n < t.N; n++ {
ms.Seek(searchPrefix, func(k, v []byte) {})
}
})
}
}

View file

@ -12,7 +12,7 @@ type mockedRedisStore struct {
mini *miniredis.Miniredis mini *miniredis.Miniredis
} }
func prepareRedisMock(t *testing.T) (*miniredis.Miniredis, *RedisStore) { func prepareRedisMock(t testing.TB) (*miniredis.Miniredis, *RedisStore) {
miniRedis, err := miniredis.Run() miniRedis, err := miniredis.Run()
require.Nil(t, err, "MiniRedis mock creation error") require.Nil(t, err, "MiniRedis mock creation error")
@ -37,7 +37,7 @@ func (mrs *mockedRedisStore) Close() error {
return err return err
} }
func newRedisStoreForTesting(t *testing.T) Store { func newRedisStoreForTesting(t testing.TB) Store {
mock, rs := prepareRedisMock(t) mock, rs := prepareRedisMock(t)
mrs := &mockedRedisStore{RedisStore: *rs, mini: mock} mrs := &mockedRedisStore{RedisStore: *rs, mini: mock}
return mrs return mrs

View file

@ -55,7 +55,8 @@ type (
// PutChangeSet allows to push prepared changeset to the Store. // PutChangeSet allows to push prepared changeset to the Store.
PutChangeSet(puts map[string][]byte, dels map[string]bool) error PutChangeSet(puts map[string][]byte, dels map[string]bool) error
// Seek can guarantee that provided key (k) and value (v) are the only valid until the next call to f. // Seek can guarantee that provided key (k) and value (v) are the only valid until the next call to f.
// Key and value slices should not be modified. // Key and value slices should not be modified. Seek can guarantee that key-value items are sorted by
// key in ascending way.
Seek(k []byte, f func(k, v []byte)) Seek(k []byte, f func(k, v []byte))
Close() error Close() error
} }

View file

@ -19,7 +19,7 @@ type kvSeen struct {
type dbSetup struct { type dbSetup struct {
name string name string
create func(*testing.T) Store create func(testing.TB) Store
} }
type dbTestFunction func(*testing.T, Store) type dbTestFunction func(*testing.T, Store)

View file

@ -20,10 +20,11 @@ type Invoke struct {
FaultException string FaultException string
Transaction *transaction.Transaction Transaction *transaction.Transaction
maxIteratorResultItems int maxIteratorResultItems int
finalize func()
} }
// NewInvoke returns new Invoke structure with the given fields set. // NewInvoke returns new Invoke structure with the given fields set.
func NewInvoke(vm *vm.VM, script []byte, faultException string, maxIteratorResultItems int) *Invoke { func NewInvoke(vm *vm.VM, finalize func(), script []byte, faultException string, maxIteratorResultItems int) *Invoke {
return &Invoke{ return &Invoke{
State: vm.State().String(), State: vm.State().String(),
GasConsumed: vm.GasConsumed(), GasConsumed: vm.GasConsumed(),
@ -31,6 +32,7 @@ func NewInvoke(vm *vm.VM, script []byte, faultException string, maxIteratorResul
Stack: vm.Estack().ToArray(), Stack: vm.Estack().ToArray(),
FaultException: faultException, FaultException: faultException,
maxIteratorResultItems: maxIteratorResultItems, maxIteratorResultItems: maxIteratorResultItems,
finalize: finalize,
} }
} }
@ -55,8 +57,17 @@ type Iterator struct {
Truncated bool Truncated bool
} }
// Finalize releases resources occupied by Iterators created at the script invocation.
// This method will be called automatically on Invoke marshalling.
func (r *Invoke) Finalize() {
if r.finalize != nil {
r.finalize()
}
}
// MarshalJSON implements json.Marshaler. // MarshalJSON implements json.Marshaler.
func (r Invoke) MarshalJSON() ([]byte, error) { func (r Invoke) MarshalJSON() ([]byte, error) {
defer r.Finalize()
var st json.RawMessage var st json.RawMessage
arr := make([]json.RawMessage, len(r.Stack)) arr := make([]json.RawMessage, len(r.Stack))
for i := range arr { for i := range arr {

View file

@ -714,7 +714,7 @@ func TestCreateNEP17TransferTx(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, acc.SignTx(testchain.Network(), tx)) require.NoError(t, acc.SignTx(testchain.Network(), tx))
require.NoError(t, chain.VerifyTx(tx)) require.NoError(t, chain.VerifyTx(tx))
v := chain.GetTestVM(trigger.Application, tx, nil) v, _ := chain.GetTestVM(trigger.Application, tx, nil)
v.LoadScriptWithFlags(tx.Script, callflag.All) v.LoadScriptWithFlags(tx.Script, callflag.All)
require.NoError(t, v.Run()) require.NoError(t, v.Run())
} }

View file

@ -624,6 +624,7 @@ func (s *Server) calculateNetworkFee(reqParams request.Params) (interface{}, *re
if respErr != nil { if respErr != nil {
return 0, respErr return 0, respErr
} }
res.Finalize()
if res.State != "HALT" { if res.State != "HALT" {
cause := fmt.Errorf("invalid VM state %s due to an error: %s", res.State, res.FaultException) cause := fmt.Errorf("invalid VM state %s due to an error: %s", res.State, res.FaultException)
return 0, response.NewRPCError(verificationErr, cause.Error(), cause) return 0, response.NewRPCError(verificationErr, cause.Error(), cause)
@ -742,7 +743,8 @@ func (s *Server) getNEP17Balance(h util.Uint160, acc util.Uint160, bw *io.BufBin
} }
script := bw.Bytes() script := bw.Bytes()
tx := &transaction.Transaction{Script: script} tx := &transaction.Transaction{Script: script}
v := s.chain.GetTestVM(trigger.Application, tx, nil) v, finalize := s.chain.GetTestVM(trigger.Application, tx, nil)
defer finalize()
v.GasLimit = core.HeaderVerificationGasLimit v.GasLimit = core.HeaderVerificationGasLimit
v.LoadScriptWithFlags(script, callflag.All) v.LoadScriptWithFlags(script, callflag.All)
err := v.Run() err := v.Run()
@ -1490,7 +1492,6 @@ func (s *Server) invokeContractVerify(reqParams request.Params) (interface{}, *r
tx.Signers = []transaction.Signer{{Account: scriptHash}} tx.Signers = []transaction.Signer{{Account: scriptHash}}
tx.Scripts = []transaction.Witness{{InvocationScript: invocationScript, VerificationScript: []byte{}}} tx.Scripts = []transaction.Witness{{InvocationScript: invocationScript, VerificationScript: []byte{}}}
} }
return s.runScriptInVM(trigger.Verification, invocationScript, scriptHash, tx) return s.runScriptInVM(trigger.Verification, invocationScript, scriptHash, tx)
} }
@ -1511,7 +1512,7 @@ func (s *Server) runScriptInVM(t trigger.Type, script []byte, contractScriptHash
} }
b.Timestamp = hdr.Timestamp + uint64(s.chain.GetConfig().SecondsPerBlock*int(time.Second/time.Millisecond)) b.Timestamp = hdr.Timestamp + uint64(s.chain.GetConfig().SecondsPerBlock*int(time.Second/time.Millisecond))
vm := s.chain.GetTestVM(t, tx, b) vm, finalize := s.chain.GetTestVM(t, tx, b)
vm.GasLimit = int64(s.config.MaxGasInvoke) vm.GasLimit = int64(s.config.MaxGasInvoke)
if t == trigger.Verification { if t == trigger.Verification {
// We need this special case because witnesses verification is not the simple System.Contract.Call, // We need this special case because witnesses verification is not the simple System.Contract.Call,
@ -1539,7 +1540,7 @@ func (s *Server) runScriptInVM(t trigger.Type, script []byte, contractScriptHash
if err != nil { if err != nil {
faultException = err.Error() faultException = err.Error()
} }
return result.NewInvoke(vm, script, faultException, s.config.MaxIteratorResultItems), nil return result.NewInvoke(vm, finalize, script, faultException, s.config.MaxIteratorResultItems), nil
} }
// submitBlock broadcasts a raw block over the NEO network. // submitBlock broadcasts a raw block over the NEO network.

View file

@ -134,16 +134,17 @@ func (o *Oracle) CreateResponseTx(gasForResponse int64, vub uint32, resp *transa
} }
func (o *Oracle) testVerify(tx *transaction.Transaction) (int64, bool) { func (o *Oracle) testVerify(tx *transaction.Transaction) (int64, bool) {
v := o.Chain.GetTestVM(trigger.Verification, tx, nil) v, finalize := o.Chain.GetTestVM(trigger.Verification, tx, nil)
v.GasLimit = o.Chain.GetPolicer().GetMaxVerificationGAS() v.GasLimit = o.Chain.GetPolicer().GetMaxVerificationGAS()
v.LoadScriptWithHash(o.oracleScript, o.oracleHash, callflag.ReadOnly) v.LoadScriptWithHash(o.oracleScript, o.oracleHash, callflag.ReadOnly)
v.Jump(v.Context(), o.verifyOffset) v.Jump(v.Context(), o.verifyOffset)
ok := isVerifyOk(v) ok := isVerifyOk(v, finalize)
return v.GasConsumed(), ok return v.GasConsumed(), ok
} }
func isVerifyOk(v *vm.VM) bool { func isVerifyOk(v *vm.VM, finalize func()) bool {
defer finalize()
if err := v.Run(); err != nil { if err := v.Run(); err != nil {
return false return false
} }