core: move write caching layer into MemCacheStore

Simplify Blockchain and associated functions, deduplicate code, fix Get() and
Seek() implementations.
This commit is contained in:
Roman Khimov 2019-10-16 16:41:50 +03:00
parent 4822c736bb
commit fc0031e5aa
10 changed files with 283 additions and 197 deletions

View file

@ -14,24 +14,18 @@ type Accounts map[util.Uint160]*AccountState
// getAndUpdate retrieves AccountState from temporary or persistent Store // getAndUpdate retrieves AccountState from temporary or persistent Store
// or creates a new one if it doesn't exist. // or creates a new one if it doesn't exist.
func (a Accounts) getAndUpdate(ts storage.Store, ps storage.Store, hash util.Uint160) (*AccountState, error) { func (a Accounts) getAndUpdate(s storage.Store, hash util.Uint160) (*AccountState, error) {
if account, ok := a[hash]; ok { if account, ok := a[hash]; ok {
return account, nil return account, nil
} }
account, err := getAccountStateFromStore(ts, hash) account, err := getAccountStateFromStore(s, hash)
if err != nil {
if err != storage.ErrKeyNotFound {
return nil, err
}
account, err = getAccountStateFromStore(ps, hash)
if err != nil { if err != nil {
if err != storage.ErrKeyNotFound { if err != storage.ErrKeyNotFound {
return nil, err return nil, err
} }
account = NewAccountState(hash) account = NewAccountState(hash)
} }
}
a[hash] = account a[hash] = account
return account, nil return account, nil

View file

@ -43,11 +43,8 @@ var (
type Blockchain struct { type Blockchain struct {
config config.ProtocolConfiguration config config.ProtocolConfiguration
// Any object that satisfies the BlockchainStorer interface. // Persistent storage wrapped around with a write memory caching layer.
storage.Store store *storage.MemCachedStore
// In-memory storage to be persisted into the storage.Store
memStore *storage.MemoryStore
// Current index/height of the highest block. // Current index/height of the highest block.
// Read access should always be called by BlockHeight(). // Read access should always be called by BlockHeight().
@ -78,8 +75,7 @@ type headersOpFunc func(headerList *HeaderHashList)
func NewBlockchain(s storage.Store, cfg config.ProtocolConfiguration) (*Blockchain, error) { func NewBlockchain(s storage.Store, cfg config.ProtocolConfiguration) (*Blockchain, error) {
bc := &Blockchain{ bc := &Blockchain{
config: cfg, config: cfg,
Store: s, store: storage.NewMemCachedStore(s),
memStore: storage.NewMemoryStore(),
headersOp: make(chan headersOpFunc), headersOp: make(chan headersOpFunc),
headersOpDone: make(chan struct{}), headersOpDone: make(chan struct{}),
memPool: NewMemPool(50000), memPool: NewMemPool(50000),
@ -94,10 +90,10 @@ func NewBlockchain(s storage.Store, cfg config.ProtocolConfiguration) (*Blockcha
func (bc *Blockchain) init() error { func (bc *Blockchain) init() error {
// If we could not find the version in the Store, we know that there is nothing stored. // If we could not find the version in the Store, we know that there is nothing stored.
ver, err := storage.Version(bc.Store) ver, err := storage.Version(bc.store)
if err != nil { if err != nil {
log.Infof("no storage version found! creating genesis block") log.Infof("no storage version found! creating genesis block")
if err = storage.PutVersion(bc.Store, version); err != nil { if err = storage.PutVersion(bc.store, version); err != nil {
return err return err
} }
genesisBlock, err := createGenesisBlock(bc.config) genesisBlock, err := createGenesisBlock(bc.config)
@ -116,14 +112,14 @@ func (bc *Blockchain) init() error {
// and the genesis block as first block. // and the genesis block as first block.
log.Infof("restoring blockchain with version: %s", version) log.Infof("restoring blockchain with version: %s", version)
bHeight, err := storage.CurrentBlockHeight(bc.Store) bHeight, err := storage.CurrentBlockHeight(bc.store)
if err != nil { if err != nil {
return err return err
} }
bc.blockHeight = bHeight bc.blockHeight = bHeight
bc.persistedHeight = bHeight bc.persistedHeight = bHeight
hashes, err := storage.HeaderHashes(bc.Store) hashes, err := storage.HeaderHashes(bc.store)
if err != nil { if err != nil {
return err return err
} }
@ -131,7 +127,7 @@ func (bc *Blockchain) init() error {
bc.headerList = NewHeaderHashList(hashes...) bc.headerList = NewHeaderHashList(hashes...)
bc.storedHeaderCount = uint32(len(hashes)) bc.storedHeaderCount = uint32(len(hashes))
currHeaderHeight, currHeaderHash, err := storage.CurrentHeaderHeight(bc.Store) currHeaderHeight, currHeaderHash, err := storage.CurrentHeaderHeight(bc.store)
if err != nil { if err != nil {
return err return err
} }
@ -173,9 +169,7 @@ func (bc *Blockchain) Run(ctx context.Context) {
if err := bc.persist(ctx); err != nil { if err := bc.persist(ctx); err != nil {
log.Warnf("failed to persist: %s", err) log.Warnf("failed to persist: %s", err)
} }
// never fails if err := bc.store.Close(); err != nil {
_ = bc.memStore.Close()
if err := bc.Store.Close(); err != nil {
log.Warnf("failed to close db: %s", err) log.Warnf("failed to close db: %s", err)
} }
}() }()
@ -237,7 +231,7 @@ func (bc *Blockchain) AddBlock(block *Block) error {
func (bc *Blockchain) AddHeaders(headers ...*Header) (err error) { func (bc *Blockchain) AddHeaders(headers ...*Header) (err error) {
var ( var (
start = time.Now() start = time.Now()
batch = bc.memStore.Batch() batch = bc.store.Batch()
) )
bc.headersOp <- func(headerList *HeaderHashList) { bc.headersOp <- func(headerList *HeaderHashList) {
@ -263,7 +257,7 @@ func (bc *Blockchain) AddHeaders(headers ...*Header) (err error) {
} }
if oldlen != headerList.Len() { if oldlen != headerList.Len() {
if err = bc.memStore.PutBatch(batch); err != nil { if err = bc.store.PutBatch(batch); err != nil {
return return
} }
log.WithFields(log.Fields{ log.WithFields(log.Fields{
@ -312,7 +306,7 @@ func (bc *Blockchain) processHeader(h *Header, batch storage.Batch, headerList *
// and all tests are in place, we can make a more optimized and cleaner implementation. // and all tests are in place, we can make a more optimized and cleaner implementation.
func (bc *Blockchain) storeBlock(block *Block) error { func (bc *Blockchain) storeBlock(block *Block) error {
var ( var (
batch = bc.memStore.Batch() batch = bc.store.Batch()
unspentCoins = make(UnspentCoins) unspentCoins = make(UnspentCoins)
spentCoins = make(SpentCoins) spentCoins = make(SpentCoins)
accounts = make(Accounts) accounts = make(Accounts)
@ -335,7 +329,7 @@ func (bc *Blockchain) storeBlock(block *Block) error {
// Process TX outputs. // Process TX outputs.
for _, output := range tx.Outputs { for _, output := range tx.Outputs {
account, err := accounts.getAndUpdate(bc.memStore, bc.Store, output.ScriptHash) account, err := accounts.getAndUpdate(bc.store, output.ScriptHash)
if err != nil { if err != nil {
return err return err
} }
@ -353,14 +347,14 @@ func (bc *Blockchain) storeBlock(block *Block) error {
return fmt.Errorf("could not find previous TX: %s", prevHash) return fmt.Errorf("could not find previous TX: %s", prevHash)
} }
for _, input := range inputs { for _, input := range inputs {
unspent, err := unspentCoins.getAndUpdate(bc.memStore, bc.Store, input.PrevHash) unspent, err := unspentCoins.getAndUpdate(bc.store, input.PrevHash)
if err != nil { if err != nil {
return err return err
} }
unspent.states[input.PrevIndex] = CoinStateSpent unspent.states[input.PrevIndex] = CoinStateSpent
prevTXOutput := prevTX.Outputs[input.PrevIndex] prevTXOutput := prevTX.Outputs[input.PrevIndex]
account, err := accounts.getAndUpdate(bc.memStore, bc.Store, prevTXOutput.ScriptHash) account, err := accounts.getAndUpdate(bc.store, prevTXOutput.ScriptHash)
if err != nil { if err != nil {
return err return err
} }
@ -421,13 +415,13 @@ func (bc *Blockchain) storeBlock(block *Block) error {
return cs.Script return cs.Script
}) })
systemInterop := newInteropContext(0x10, bc, block, tx) systemInterop := newInteropContext(0x10, bc, bc.store, block, tx)
vm.RegisterInteropFuncs(systemInterop.getSystemInteropMap()) vm.RegisterInteropFuncs(systemInterop.getSystemInteropMap())
vm.RegisterInteropFuncs(systemInterop.getNeoInteropMap()) vm.RegisterInteropFuncs(systemInterop.getNeoInteropMap())
vm.LoadScript(t.Script) vm.LoadScript(t.Script)
vm.Run() vm.Run()
if !vm.HasFailed() { if !vm.HasFailed() {
_, err := systemInterop.mem.Persist(bc.memStore) _, err := systemInterop.mem.Persist()
if err != nil { if err != nil {
return errors.Wrap(err, "failed to persist invocation results") return errors.Wrap(err, "failed to persist invocation results")
} }
@ -456,7 +450,7 @@ func (bc *Blockchain) storeBlock(block *Block) error {
if err := contracts.commit(batch); err != nil { if err := contracts.commit(batch); err != nil {
return err return err
} }
if err := bc.memStore.PutBatch(batch); err != nil { if err := bc.store.PutBatch(batch); err != nil {
return err return err
} }
@ -472,11 +466,11 @@ func (bc *Blockchain) persist(ctx context.Context) error {
err error err error
) )
persisted, err = bc.memStore.Persist(bc.Store) persisted, err = bc.store.Persist()
if err != nil { if err != nil {
return err return err
} }
bHeight, err := storage.CurrentBlockHeight(bc.Store) bHeight, err := storage.CurrentBlockHeight(bc.store)
if err != nil { if err != nil {
return err return err
} }
@ -510,11 +504,7 @@ func (bc *Blockchain) GetTransaction(hash util.Uint256) (*transaction.Transactio
if tx, ok := bc.memPool.TryGetValue(hash); ok { if tx, ok := bc.memPool.TryGetValue(hash); ok {
return tx, 0, nil // the height is not actually defined for memPool transaction. Not sure if zero is a good number in this case. return tx, 0, nil // the height is not actually defined for memPool transaction. Not sure if zero is a good number in this case.
} }
tx, height, err := getTransactionFromStore(bc.memStore, hash) return getTransactionFromStore(bc.store, hash)
if err != nil {
tx, height, err = getTransactionFromStore(bc.Store, hash)
}
return tx, height, err
} }
// getTransactionFromStore returns Transaction and its height by the given hash // getTransactionFromStore returns Transaction and its height by the given hash
@ -541,11 +531,7 @@ func getTransactionFromStore(s storage.Store, hash util.Uint256) (*transaction.T
// GetStorageItem returns an item from storage. // GetStorageItem returns an item from storage.
func (bc *Blockchain) GetStorageItem(scripthash util.Uint160, key []byte) *StorageItem { func (bc *Blockchain) GetStorageItem(scripthash util.Uint160, key []byte) *StorageItem {
sItem := getStorageItemFromStore(bc.memStore, scripthash, key) return getStorageItemFromStore(bc.store, scripthash, key)
if sItem == nil {
sItem = getStorageItemFromStore(bc.Store, scripthash, key)
}
return sItem
} }
// GetStorageItems returns all storage items for a given scripthash. // GetStorageItems returns all storage items for a given scripthash.
@ -568,8 +554,7 @@ func (bc *Blockchain) GetStorageItems(hash util.Uint160) (map[string]*StorageIte
// Cut prefix and hash. // Cut prefix and hash.
siMap[string(k[21:])] = si siMap[string(k[21:])] = si
} }
bc.memStore.Seek(storage.AppendPrefix(storage.STStorage, hash.BytesReverse()), saveToMap) bc.store.Seek(storage.AppendPrefix(storage.STStorage, hash.BytesReverse()), saveToMap)
bc.Store.Seek(storage.AppendPrefix(storage.STStorage, hash.BytesReverse()), saveToMap)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -578,13 +563,10 @@ func (bc *Blockchain) GetStorageItems(hash util.Uint160) (map[string]*StorageIte
// GetBlock returns a Block by the given hash. // GetBlock returns a Block by the given hash.
func (bc *Blockchain) GetBlock(hash util.Uint256) (*Block, error) { func (bc *Blockchain) GetBlock(hash util.Uint256) (*Block, error) {
block, err := getBlockFromStore(bc.memStore, hash) block, err := getBlockFromStore(bc.store, hash)
if err != nil {
block, err = getBlockFromStore(bc.Store, hash)
if err != nil { if err != nil {
return nil, err return nil, err
} }
}
if len(block.Transactions) == 0 { if len(block.Transactions) == 0 {
return nil, fmt.Errorf("only header is available") return nil, fmt.Errorf("only header is available")
} }
@ -614,14 +596,7 @@ func getBlockFromStore(s storage.Store, hash util.Uint256) (*Block, error) {
// GetHeader returns data block header identified with the given hash value. // GetHeader returns data block header identified with the given hash value.
func (bc *Blockchain) GetHeader(hash util.Uint256) (*Header, error) { func (bc *Blockchain) GetHeader(hash util.Uint256) (*Header, error) {
header, err := getHeaderFromStore(bc.memStore, hash) return getHeaderFromStore(bc.store, hash)
if err != nil {
header, err = getHeaderFromStore(bc.Store, hash)
if err != nil {
return nil, err
}
}
return header, err
} }
// getHeaderFromStore returns Header by the given hash from the store. // getHeaderFromStore returns Header by the given hash from the store.
@ -637,8 +612,7 @@ func getHeaderFromStore(s storage.Store, hash util.Uint256) (*Header, error) {
// transaction hash. // transaction hash.
func (bc *Blockchain) HasTransaction(hash util.Uint256) bool { func (bc *Blockchain) HasTransaction(hash util.Uint256) bool {
return bc.memPool.ContainsKey(hash) || return bc.memPool.ContainsKey(hash) ||
checkTransactionInStore(bc.memStore, hash) || checkTransactionInStore(bc.store, hash)
checkTransactionInStore(bc.Store, hash)
} }
// checkTransactionInStore returns true if the given store contains the given // checkTransactionInStore returns true if the given store contains the given
@ -700,11 +674,7 @@ func (bc *Blockchain) HeaderHeight() uint32 {
// GetAssetState returns asset state from its assetID // GetAssetState returns asset state from its assetID
func (bc *Blockchain) GetAssetState(assetID util.Uint256) *AssetState { func (bc *Blockchain) GetAssetState(assetID util.Uint256) *AssetState {
as := getAssetStateFromStore(bc.memStore, assetID) return getAssetStateFromStore(bc.store, assetID)
if as == nil {
as = getAssetStateFromStore(bc.Store, assetID)
}
return as
} }
// getAssetStateFromStore returns given asset state as recorded in the given // getAssetStateFromStore returns given asset state as recorded in the given
@ -727,11 +697,7 @@ func getAssetStateFromStore(s storage.Store, assetID util.Uint256) *AssetState {
// GetContractState returns contract by its script hash. // GetContractState returns contract by its script hash.
func (bc *Blockchain) GetContractState(hash util.Uint160) *ContractState { func (bc *Blockchain) GetContractState(hash util.Uint160) *ContractState {
cs := getContractStateFromStore(bc.memStore, hash) return getContractStateFromStore(bc.store, hash)
if cs == nil {
cs = getContractStateFromStore(bc.Store, hash)
}
return cs
} }
// getContractStateFromStore returns contract state as recorded in the given // getContractStateFromStore returns contract state as recorded in the given
@ -754,24 +720,18 @@ func getContractStateFromStore(s storage.Store, hash util.Uint160) *ContractStat
// GetAccountState returns the account state from its script hash // GetAccountState returns the account state from its script hash
func (bc *Blockchain) GetAccountState(scriptHash util.Uint160) *AccountState { func (bc *Blockchain) GetAccountState(scriptHash util.Uint160) *AccountState {
as, err := getAccountStateFromStore(bc.memStore, scriptHash) as, err := getAccountStateFromStore(bc.store, scriptHash)
if as == nil {
if err != storage.ErrKeyNotFound {
log.Warnf("failed to get account state: %s", err)
}
as, err = getAccountStateFromStore(bc.Store, scriptHash)
if as == nil && err != storage.ErrKeyNotFound { if as == nil && err != storage.ErrKeyNotFound {
log.Warnf("failed to get account state: %s", err) log.Warnf("failed to get account state: %s", err)
} }
}
return as return as
} }
// GetUnspentCoinState returns unspent coin state for given tx hash. // GetUnspentCoinState returns unspent coin state for given tx hash.
func (bc *Blockchain) GetUnspentCoinState(hash util.Uint256) *UnspentCoinState { func (bc *Blockchain) GetUnspentCoinState(hash util.Uint256) *UnspentCoinState {
ucs, err := getUnspentCoinStateFromStore(bc.memStore, hash) ucs, err := getUnspentCoinStateFromStore(bc.store, hash)
if err != nil { if ucs == nil && err != storage.ErrKeyNotFound {
ucs, _ = getUnspentCoinStateFromStore(bc.Store, hash) log.Warnf("failed to get unspent coin state: %s", err)
} }
return ucs return ucs
} }
@ -872,7 +832,7 @@ func (bc *Blockchain) VerifyTx(t *transaction.Transaction, block *Block) error {
if ok := bc.memPool.Verify(t); !ok { if ok := bc.memPool.Verify(t); !ok {
return errors.New("invalid transaction due to conflicts with the memory pool") return errors.New("invalid transaction due to conflicts with the memory pool")
} }
if IsDoubleSpend(bc.Store, t) { if IsDoubleSpend(bc.store, t) {
return errors.New("invalid transaction caused by double spending") return errors.New("invalid transaction caused by double spending")
} }
if err := bc.verifyOutputs(t); err != nil { if err := bc.verifyOutputs(t); err != nil {
@ -1180,7 +1140,7 @@ func (bc *Blockchain) verifyTxWitnesses(t *transaction.Transaction, block *Block
} }
sort.Slice(hashes, func(i, j int) bool { return hashes[i].Less(hashes[j]) }) sort.Slice(hashes, func(i, j int) bool { return hashes[i].Less(hashes[j]) })
sort.Slice(witnesses, func(i, j int) bool { return witnesses[i].ScriptHash().Less(witnesses[j].ScriptHash()) }) sort.Slice(witnesses, func(i, j int) bool { return witnesses[i].ScriptHash().Less(witnesses[j].ScriptHash()) })
interopCtx := newInteropContext(0, bc, block, t) interopCtx := newInteropContext(0, bc, bc.store, block, t)
for i := 0; i < len(hashes); i++ { for i := 0; i < len(hashes); i++ {
err := bc.verifyHashAgainstScript(hashes[i], witnesses[i], t.VerificationHash(), interopCtx) err := bc.verifyHashAgainstScript(hashes[i], witnesses[i], t.VerificationHash(), interopCtx)
if err != nil { if err != nil {
@ -1200,7 +1160,7 @@ func (bc *Blockchain) verifyBlockWitnesses(block *Block, prevHeader *Header) err
} else { } else {
hash = prevHeader.NextConsensus hash = prevHeader.NextConsensus
} }
interopCtx := newInteropContext(0, bc, nil, nil) interopCtx := newInteropContext(0, bc, bc.store, nil, nil)
return bc.verifyHashAgainstScript(hash, block.Script, block.VerificationHash(), interopCtx) return bc.verifyHashAgainstScript(hash, block.Script, block.VerificationHash(), interopCtx)
} }

View file

@ -57,7 +57,7 @@ func TestAddBlock(t *testing.T) {
for _, block := range blocks { for _, block := range blocks {
key := storage.AppendPrefix(storage.DataBlock, block.Hash().BytesReverse()) key := storage.AppendPrefix(storage.DataBlock, block.Hash().BytesReverse())
if _, err := bc.Get(key); err != nil { if _, err := bc.store.Get(key); err != nil {
t.Fatalf("block %s not persisted", block.Hash()) t.Fatalf("block %s not persisted", block.Hash())
} }
} }

View file

@ -18,11 +18,11 @@ type interopContext struct {
trigger byte trigger byte
block *Block block *Block
tx *transaction.Transaction tx *transaction.Transaction
mem *storage.MemoryStore mem *storage.MemCachedStore
} }
func newInteropContext(trigger byte, bc Blockchainer, block *Block, tx *transaction.Transaction) *interopContext { func newInteropContext(trigger byte, bc Blockchainer, s storage.Store, block *Block, tx *transaction.Transaction) *interopContext {
mem := storage.NewMemoryStore() mem := storage.NewMemCachedStore(s)
return &interopContext{bc, trigger, block, tx, mem} return &interopContext{bc, trigger, block, tx, mem}
} }

View file

@ -0,0 +1,85 @@
package storage
// MemCachedStore is a wrapper around persistent store that caches all changes
// being made for them to be later flushed in one batch.
type MemCachedStore struct {
MemoryStore
// Persistent Store.
ps Store
}
// NewMemCachedStore creates a new MemCachedStore object.
func NewMemCachedStore(lower Store) *MemCachedStore {
return &MemCachedStore{
MemoryStore: *NewMemoryStore(),
ps: lower,
}
}
// Get implements the Store interface.
func (s *MemCachedStore) Get(key []byte) ([]byte, error) {
s.mut.RLock()
defer s.mut.RUnlock()
k := string(key)
if val, ok := s.mem[k]; ok {
return val, nil
}
if _, ok := s.del[k]; ok {
return nil, ErrKeyNotFound
}
return s.ps.Get(key)
}
// Seek implements the Store interface.
func (s *MemCachedStore) Seek(key []byte, f func(k, v []byte)) {
s.mut.RLock()
defer s.mut.RUnlock()
s.MemoryStore.Seek(key, f)
s.ps.Seek(key, func(k, v []byte) {
elem := string(k)
// If it's in mem, we already called f() for it in MemoryStore.Seek().
_, present := s.mem[elem]
if !present {
// If it's in del, we shouldn't be calling f() anyway.
_, present = s.del[elem]
}
if !present {
f(k, v)
}
})
}
// Persist flushes all the MemoryStore contents into the (supposedly) persistent
// store ps.
func (s *MemCachedStore) Persist() (int, error) {
s.mut.Lock()
defer s.mut.Unlock()
batch := s.ps.Batch()
keys, dkeys := 0, 0
for k, v := range s.mem {
batch.Put([]byte(k), v)
keys++
}
for k := range s.del {
batch.Delete([]byte(k))
dkeys++
}
var err error
if keys != 0 || dkeys != 0 {
err = s.ps.PutBatch(batch)
}
if err == nil {
s.mem = make(map[string][]byte)
s.del = make(map[string]bool)
}
return keys, err
}
// Close implements Store interface, clears up memory and closes the lower layer
// Store.
func (s *MemCachedStore) Close() error {
// It's always successful.
_ = s.MemoryStore.Close()
return s.ps.Close()
}

View file

@ -0,0 +1,141 @@
package storage
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestMemCachedStorePersist(t *testing.T) {
// persistent Store
ps := NewMemoryStore()
// cached Store
ts := NewMemCachedStore(ps)
// persisting nothing should do nothing
c, err := ts.Persist()
assert.Equal(t, nil, err)
assert.Equal(t, 0, c)
// persisting one key should result in one key in ps and nothing in ts
assert.NoError(t, ts.Put([]byte("key"), []byte("value")))
c, err = ts.Persist()
assert.Equal(t, nil, err)
assert.Equal(t, 1, c)
v, err := ps.Get([]byte("key"))
assert.Equal(t, nil, err)
assert.Equal(t, []byte("value"), v)
v, err = ts.MemoryStore.Get([]byte("key"))
assert.Equal(t, ErrKeyNotFound, err)
assert.Equal(t, []byte(nil), v)
// now we overwrite the previous `key` contents and also add `key2`,
assert.NoError(t, ts.Put([]byte("key"), []byte("newvalue")))
assert.NoError(t, ts.Put([]byte("key2"), []byte("value2")))
// this is to check that now key is written into the ps before we do
// persist
v, err = ps.Get([]byte("key2"))
assert.Equal(t, ErrKeyNotFound, err)
assert.Equal(t, []byte(nil), v)
// two keys should be persisted (one overwritten and one new) and
// available in the ps
c, err = ts.Persist()
assert.Equal(t, nil, err)
assert.Equal(t, 2, c)
v, err = ts.MemoryStore.Get([]byte("key"))
assert.Equal(t, ErrKeyNotFound, err)
assert.Equal(t, []byte(nil), v)
v, err = ts.MemoryStore.Get([]byte("key2"))
assert.Equal(t, ErrKeyNotFound, err)
assert.Equal(t, []byte(nil), v)
v, err = ps.Get([]byte("key"))
assert.Equal(t, nil, err)
assert.Equal(t, []byte("newvalue"), v)
v, err = ps.Get([]byte("key2"))
assert.Equal(t, nil, err)
assert.Equal(t, []byte("value2"), v)
// we've persisted some values, make sure successive persist is a no-op
c, err = ts.Persist()
assert.Equal(t, nil, err)
assert.Equal(t, 0, c)
// test persisting deletions
err = ts.Delete([]byte("key"))
assert.Equal(t, nil, err)
c, err = ts.Persist()
assert.Equal(t, nil, err)
assert.Equal(t, 0, c)
v, err = ps.Get([]byte("key"))
assert.Equal(t, ErrKeyNotFound, err)
assert.Equal(t, []byte(nil), v)
v, err = ps.Get([]byte("key2"))
assert.Equal(t, nil, err)
assert.Equal(t, []byte("value2"), v)
}
func TestCachedGetFromPersistent(t *testing.T) {
key := []byte("key")
value := []byte("value")
ps := NewMemoryStore()
ts := NewMemCachedStore(ps)
assert.NoError(t, ps.Put(key, value))
val, err := ts.Get(key)
assert.Nil(t, err)
assert.Equal(t, value, val)
assert.NoError(t, ts.Delete(key))
val, err = ts.Get(key)
assert.Equal(t, err, ErrKeyNotFound)
assert.Nil(t, val)
}
func TestCachedSeek(t *testing.T) {
var (
// Given this prefix...
goodPrefix = []byte{'f'}
// these pairs should be found...
lowerKVs = []kvSeen{
{[]byte("foo"), []byte("bar"), false},
{[]byte("faa"), []byte("bra"), false},
}
// and these should be not.
deletedKVs = []kvSeen{
{[]byte("fee"), []byte("pow"), false},
{[]byte("fii"), []byte("qaz"), false},
}
// and these should be not.
updatedKVs = []kvSeen{
{[]byte("fuu"), []byte("wop"), false},
{[]byte("fyy"), []byte("zaq"), false},
}
ps = NewMemoryStore()
ts = NewMemCachedStore(ps)
)
for _, v := range lowerKVs {
require.NoError(t, ps.Put(v.key, v.val))
}
for _, v := range deletedKVs {
require.NoError(t, ps.Put(v.key, v.val))
require.NoError(t, ts.Delete(v.key))
}
for _, v := range updatedKVs {
require.NoError(t, ps.Put(v.key, []byte("stub")))
require.NoError(t, ts.Put(v.key, v.val))
}
foundKVs := make(map[string][]byte)
ts.Seek(goodPrefix, func(k, v []byte) {
foundKVs[string(k)] = v
})
assert.Equal(t, len(foundKVs), len(lowerKVs)+len(updatedKVs))
for _, kv := range lowerKVs {
assert.Equal(t, kv.val, foundKVs[string(kv.key)])
}
for _, kv := range deletedKVs {
_, ok := foundKVs[string(kv.key)]
assert.Equal(t, false, ok)
}
for _, kv := range updatedKVs {
assert.Equal(t, kv.val, foundKVs[string(kv.key)])
}
}
func newMemCachedStoreForTesting(t *testing.T) Store {
return NewMemCachedStore(NewMemoryStore())
}

View file

@ -114,32 +114,6 @@ func newMemoryBatch() *MemoryBatch {
return &MemoryBatch{MemoryStore: *NewMemoryStore()} return &MemoryBatch{MemoryStore: *NewMemoryStore()}
} }
// Persist flushes all the MemoryStore contents into the (supposedly) persistent
// store provided via parameter.
func (s *MemoryStore) Persist(ps Store) (int, error) {
s.mut.Lock()
defer s.mut.Unlock()
batch := ps.Batch()
keys, dkeys := 0, 0
for k, v := range s.mem {
batch.Put([]byte(k), v)
keys++
}
for k := range s.del {
batch.Delete([]byte(k))
dkeys++
}
var err error
if keys != 0 || dkeys != 0 {
err = ps.PutBatch(batch)
}
if err == nil {
s.mem = make(map[string][]byte)
s.del = make(map[string]bool)
}
return keys, err
}
// Close implements Store interface and clears up memory. Never returns an // Close implements Store interface and clears up memory. Never returns an
// error. // error.
func (s *MemoryStore) Close() error { func (s *MemoryStore) Close() error {

View file

@ -2,73 +2,8 @@ package storage
import ( import (
"testing" "testing"
"github.com/stretchr/testify/assert"
) )
func TestMemoryStorePersist(t *testing.T) {
// temporary Store
ts := NewMemoryStore()
// persistent Store
ps := NewMemoryStore()
// persisting nothing should do nothing
c, err := ts.Persist(ps)
assert.Equal(t, nil, err)
assert.Equal(t, 0, c)
// persisting one key should result in one key in ps and nothing in ts
assert.NoError(t, ts.Put([]byte("key"), []byte("value")))
c, err = ts.Persist(ps)
assert.Equal(t, nil, err)
assert.Equal(t, 1, c)
v, err := ps.Get([]byte("key"))
assert.Equal(t, nil, err)
assert.Equal(t, []byte("value"), v)
v, err = ts.Get([]byte("key"))
assert.Equal(t, ErrKeyNotFound, err)
assert.Equal(t, []byte(nil), v)
// now we overwrite the previous `key` contents and also add `key2`,
assert.NoError(t, ts.Put([]byte("key"), []byte("newvalue")))
assert.NoError(t, ts.Put([]byte("key2"), []byte("value2")))
// this is to check that now key is written into the ps before we do
// persist
v, err = ps.Get([]byte("key2"))
assert.Equal(t, ErrKeyNotFound, err)
assert.Equal(t, []byte(nil), v)
// two keys should be persisted (one overwritten and one new) and
// available in the ps
c, err = ts.Persist(ps)
assert.Equal(t, nil, err)
assert.Equal(t, 2, c)
v, err = ts.Get([]byte("key"))
assert.Equal(t, ErrKeyNotFound, err)
assert.Equal(t, []byte(nil), v)
v, err = ts.Get([]byte("key2"))
assert.Equal(t, ErrKeyNotFound, err)
assert.Equal(t, []byte(nil), v)
v, err = ps.Get([]byte("key"))
assert.Equal(t, nil, err)
assert.Equal(t, []byte("newvalue"), v)
v, err = ps.Get([]byte("key2"))
assert.Equal(t, nil, err)
assert.Equal(t, []byte("value2"), v)
// we've persisted some values, make sure successive persist is a no-op
c, err = ts.Persist(ps)
assert.Equal(t, nil, err)
assert.Equal(t, 0, c)
// test persisting deletions
err = ts.Delete([]byte("key"))
assert.Equal(t, nil, err)
c, err = ts.Persist(ps)
assert.Equal(t, nil, err)
assert.Equal(t, 0, c)
v, err = ps.Get([]byte("key"))
assert.Equal(t, ErrKeyNotFound, err)
assert.Equal(t, []byte(nil), v)
v, err = ps.Get([]byte("key2"))
assert.Equal(t, nil, err)
assert.Equal(t, []byte("value2"), v)
}
func newMemoryStoreForTesting(t *testing.T) Store { func newMemoryStoreForTesting(t *testing.T) Store {
return NewMemoryStore() return NewMemoryStore()
} }

View file

@ -9,6 +9,13 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
// kvSeen is used to test Seek implementations.
type kvSeen struct {
key []byte
val []byte
seen bool
}
type dbSetup struct { type dbSetup struct {
name string name string
create func(*testing.T) Store create func(*testing.T) Store
@ -66,11 +73,6 @@ func testStorePutBatch(t *testing.T, s Store) {
} }
func testStoreSeek(t *testing.T, s Store) { func testStoreSeek(t *testing.T, s Store) {
type kvSeen struct {
key []byte
val []byte
seen bool
}
var ( var (
// Given this prefix... // Given this prefix...
goodprefix = []byte{'f'} goodprefix = []byte{'f'}
@ -219,6 +221,7 @@ func TestAllDBs(t *testing.T) {
var DBs = []dbSetup{ var DBs = []dbSetup{
{"BoltDB", newBoltStoreForTesting}, {"BoltDB", newBoltStoreForTesting},
{"LevelDB", newLevelDBForTesting}, {"LevelDB", newLevelDBForTesting},
{"MemCached", newMemCachedStoreForTesting},
{"Memory", newMemoryStoreForTesting}, {"Memory", newMemoryStoreForTesting},
{"RedisDB", newRedisStoreForTesting}, {"RedisDB", newRedisStoreForTesting},
} }

View file

@ -16,17 +16,12 @@ type UnspentCoins map[util.Uint256]*UnspentCoinState
// getAndUpdate retreives UnspentCoinState from temporary or persistent Store // getAndUpdate retreives UnspentCoinState from temporary or persistent Store
// and return it. If it's not present in both stores, returns a new // and return it. If it's not present in both stores, returns a new
// UnspentCoinState. // UnspentCoinState.
func (u UnspentCoins) getAndUpdate(ts storage.Store, ps storage.Store, hash util.Uint256) (*UnspentCoinState, error) { func (u UnspentCoins) getAndUpdate(s storage.Store, hash util.Uint256) (*UnspentCoinState, error) {
if unspent, ok := u[hash]; ok { if unspent, ok := u[hash]; ok {
return unspent, nil return unspent, nil
} }
unspent, err := getUnspentCoinStateFromStore(ts, hash) unspent, err := getUnspentCoinStateFromStore(s, hash)
if err != nil {
if err != storage.ErrKeyNotFound {
return nil, err
}
unspent, err = getUnspentCoinStateFromStore(ps, hash)
if err != nil { if err != nil {
if err != storage.ErrKeyNotFound { if err != storage.ErrKeyNotFound {
return nil, err return nil, err
@ -35,7 +30,6 @@ func (u UnspentCoins) getAndUpdate(ts storage.Store, ps storage.Store, hash util
states: []CoinState{}, states: []CoinState{},
} }
} }
}
u[hash] = unspent u[hash] = unspent
return unspent, nil return unspent, nil