Merge pull request #414 from nspcc-dev/persistence-rewamp

What started as an attempt to fix #366 ended up being quite substantial refactoring of the Blockchain->Store and Server->Blockchain interactions. As usually, some additional problems were noted and fixed along the way. It also accidentally fixes #410.
This commit is contained in:
Roman Khimov 2019-09-27 17:55:43 +03:00 committed by GitHub
commit fac778d3dd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
22 changed files with 545 additions and 399 deletions

1
go.mod
View file

@ -1,6 +1,7 @@
module github.com/CityOfZion/neo-go module github.com/CityOfZion/neo-go
require ( require (
github.com/Workiva/go-datastructures v1.0.50
github.com/abiosoft/ishell v2.0.0+incompatible // indirect github.com/abiosoft/ishell v2.0.0+incompatible // indirect
github.com/abiosoft/readline v0.0.0-20180607040430-155bce2042db // indirect github.com/abiosoft/readline v0.0.0-20180607040430-155bce2042db // indirect
github.com/alicebob/gopher-json v0.0.0-20180125190556-5a6b3ba71ee6 // indirect github.com/alicebob/gopher-json v0.0.0-20180125190556-5a6b3ba71ee6 // indirect

2
go.sum
View file

@ -1,3 +1,5 @@
github.com/Workiva/go-datastructures v1.0.50 h1:slDmfW6KCHcC7U+LP3DDBbm4fqTwZGn1beOFPfGaLvo=
github.com/Workiva/go-datastructures v1.0.50/go.mod h1:Z+F2Rca0qCsVYDS8z7bAGm8f3UkzuWYS/oBZz5a7VVA=
github.com/abiosoft/ishell v2.0.0+incompatible h1:zpwIuEHc37EzrsIYah3cpevrIc8Oma7oZPxr03tlmmw= github.com/abiosoft/ishell v2.0.0+incompatible h1:zpwIuEHc37EzrsIYah3cpevrIc8Oma7oZPxr03tlmmw=
github.com/abiosoft/ishell v2.0.0+incompatible/go.mod h1:HQR9AqF2R3P4XXpMpI0NAzgHf/aS6+zVXRj14cVk9qg= github.com/abiosoft/ishell v2.0.0+incompatible/go.mod h1:HQR9AqF2R3P4XXpMpI0NAzgHf/aS6+zVXRj14cVk9qg=
github.com/abiosoft/readline v0.0.0-20180607040430-155bce2042db h1:CjPUSXOiYptLbTdr1RceuZgSFDQ7U15ITERUGrUORx8= github.com/abiosoft/readline v0.0.0-20180607040430-155bce2042db h1:CjPUSXOiYptLbTdr1RceuZgSFDQ7U15ITERUGrUORx8=

View file

@ -12,25 +12,46 @@ import (
// Accounts is mapping between a account address and AccountState. // Accounts is mapping between a account address and AccountState.
type Accounts map[util.Uint160]*AccountState type Accounts map[util.Uint160]*AccountState
func (a Accounts) getAndUpdate(s storage.Store, hash util.Uint160) (*AccountState, error) { // getAndUpdate retrieves AccountState from temporary or persistent Store
// or creates a new one if it doesn't exist.
func (a Accounts) getAndUpdate(ts storage.Store, ps storage.Store, hash util.Uint160) (*AccountState, error) {
if account, ok := a[hash]; ok { if account, ok := a[hash]; ok {
return account, nil return account, nil
} }
account := &AccountState{} account, err := getAccountStateFromStore(ts, hash)
if err != nil {
if err != storage.ErrKeyNotFound {
return nil, err
}
account, err = getAccountStateFromStore(ps, hash)
if err != nil {
if err != storage.ErrKeyNotFound {
return nil, err
}
account = NewAccountState(hash)
}
}
a[hash] = account
return account, nil
}
// getAccountStateFromStore returns AccountState from the given Store if it's
// present there. Returns nil otherwise.
func getAccountStateFromStore(s storage.Store, hash util.Uint160) (*AccountState, error) {
var account *AccountState
key := storage.AppendPrefix(storage.STAccount, hash.Bytes()) key := storage.AppendPrefix(storage.STAccount, hash.Bytes())
if b, err := s.Get(key); err == nil { b, err := s.Get(key)
if err == nil {
account = new(AccountState)
r := io.NewBinReaderFromBuf(b) r := io.NewBinReaderFromBuf(b)
account.DecodeBinary(r) account.DecodeBinary(r)
if r.Err != nil { if r.Err != nil {
return nil, fmt.Errorf("failed to decode (AccountState): %s", r.Err) return nil, fmt.Errorf("failed to decode (AccountState): %s", r.Err)
} }
} else {
account = NewAccountState(hash)
} }
return account, err
a[hash] = account
return account, nil
} }
// commit writes all account states to the given Batch. // commit writes all account states to the given Batch.

View file

@ -5,6 +5,7 @@ import (
"github.com/CityOfZion/neo-go/pkg/crypto" "github.com/CityOfZion/neo-go/pkg/crypto"
"github.com/CityOfZion/neo-go/pkg/io" "github.com/CityOfZion/neo-go/pkg/io"
"github.com/CityOfZion/neo-go/pkg/util" "github.com/CityOfZion/neo-go/pkg/util"
"github.com/Workiva/go-datastructures/queue"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
@ -132,3 +133,16 @@ func (b *Block) EncodeBinary(bw *io.BinWriter) {
tx.EncodeBinary(bw) tx.EncodeBinary(bw)
} }
} }
// Compare implements the queue Item interface.
func (b *Block) Compare(item queue.Item) int {
other := item.(*Block)
switch {
case b.Index > other.Index:
return 1
case b.Index == other.Index:
return 0
default:
return -1
}
}

View file

@ -259,3 +259,12 @@ func TestBlockSizeCalculation(t *testing.T) {
assert.Equal(t, 7360, len(benc)) assert.Equal(t, 7360, len(benc))
assert.Equal(t, rawBlock, hex.EncodeToString(benc)) assert.Equal(t, rawBlock, hex.EncodeToString(benc))
} }
func TestBlockCompare(t *testing.T) {
b1 := Block{BlockBase: BlockBase{Index: 1}}
b2 := Block{BlockBase: BlockBase{Index: 2}}
b3 := Block{BlockBase: BlockBase{Index: 3}}
assert.Equal(t, 1, b2.Compare(&b1))
assert.Equal(t, 0, b2.Compare(&b2))
assert.Equal(t, -1, b2.Compare(&b3))
}

View file

@ -37,16 +37,20 @@ type Blockchain struct {
// Any object that satisfies the BlockchainStorer interface. // Any object that satisfies the BlockchainStorer interface.
storage.Store storage.Store
// In-memory storage to be persisted into the storage.Store
memStore *storage.MemoryStore
// Current index/height of the highest block. // Current index/height of the highest block.
// Read access should always be called by BlockHeight(). // Read access should always be called by BlockHeight().
// Write access should only happen in Persist(). // Write access should only happen in storeBlock().
blockHeight uint32 blockHeight uint32
// Current persisted block count.
persistedHeight uint32
// Number of headers stored in the chain file. // Number of headers stored in the chain file.
storedHeaderCount uint32 storedHeaderCount uint32
blockCache *Cache
// All operation on headerList must be called from an // All operation on headerList must be called from an
// headersOp to be routine safe. // headersOp to be routine safe.
headerList *HeaderHashList headerList *HeaderHashList
@ -69,9 +73,9 @@ func NewBlockchain(s storage.Store, cfg config.ProtocolConfiguration) (*Blockcha
bc := &Blockchain{ bc := &Blockchain{
config: cfg, config: cfg,
Store: s, Store: s,
memStore: storage.NewMemoryStore(),
headersOp: make(chan headersOpFunc), headersOp: make(chan headersOpFunc),
headersOpDone: make(chan struct{}), headersOpDone: make(chan struct{}),
blockCache: NewCache(),
verifyBlocks: false, verifyBlocks: false,
memPool: NewMemPool(50000), memPool: NewMemPool(50000),
} }
@ -96,7 +100,7 @@ func (bc *Blockchain) init() error {
return err return err
} }
bc.headerList = NewHeaderHashList(genesisBlock.Hash()) bc.headerList = NewHeaderHashList(genesisBlock.Hash())
return bc.persistBlock(genesisBlock) return bc.storeBlock(genesisBlock)
} }
if ver != version { if ver != version {
return fmt.Errorf("storage version mismatch betweeen %s and %s", version, ver) return fmt.Errorf("storage version mismatch betweeen %s and %s", version, ver)
@ -112,6 +116,7 @@ func (bc *Blockchain) init() error {
return err return err
} }
bc.blockHeight = bHeight bc.blockHeight = bHeight
bc.persistedHeight = bHeight
hashes, err := storage.HeaderHashes(bc.Store) hashes, err := storage.HeaderHashes(bc.Store)
if err != nil { if err != nil {
@ -144,8 +149,11 @@ func (bc *Blockchain) init() error {
} }
headerSliceReverse(headers) headerSliceReverse(headers)
if err := bc.AddHeaders(headers...); err != nil { for _, h := range headers {
return err if !h.Verify() {
return fmt.Errorf("bad header %d/%s in the storage", h.Index, h.Hash())
}
bc.headerList.Add(h.Hash())
} }
} }
@ -157,6 +165,11 @@ func (bc *Blockchain) Run(ctx context.Context) {
persistTimer := time.NewTimer(persistInterval) persistTimer := time.NewTimer(persistInterval)
defer func() { defer func() {
persistTimer.Stop() persistTimer.Stop()
if err := bc.persist(ctx); err != nil {
log.Warnf("failed to persist: %s", err)
}
// never fails
_ = bc.memStore.Close()
if err := bc.Store.Close(); err != nil { if err := bc.Store.Close(); err != nil {
log.Warnf("failed to close db: %s", err) log.Warnf("failed to close db: %s", err)
} }
@ -170,7 +183,7 @@ func (bc *Blockchain) Run(ctx context.Context) {
bc.headersOpDone <- struct{}{} bc.headersOpDone <- struct{}{}
case <-persistTimer.C: case <-persistTimer.C:
go func() { go func() {
err := bc.Persist(ctx) err := bc.persist(ctx)
if err != nil { if err != nil {
log.Warnf("failed to persist blockchain: %s", err) log.Warnf("failed to persist blockchain: %s", err)
} }
@ -180,24 +193,24 @@ func (bc *Blockchain) Run(ctx context.Context) {
} }
} }
// AddBlock processes the given block and will add it to the cache so it // AddBlock accepts successive block for the Blockchain, verifies it and
// can be persisted. // stores internally. Eventually it will be persisted to the backing storage.
func (bc *Blockchain) AddBlock(block *Block) error { func (bc *Blockchain) AddBlock(block *Block) error {
if !bc.blockCache.Has(block.Hash()) { expectedHeight := bc.BlockHeight() + 1
bc.blockCache.Add(block.Hash(), block) if expectedHeight != block.Index {
return fmt.Errorf("expected block %d, but passed block %d", expectedHeight, block.Index)
} }
headerLen := bc.headerListLen()
if int(block.Index-1) >= headerLen {
return nil
}
if int(block.Index) == headerLen {
if bc.verifyBlocks && !block.Verify(false) { if bc.verifyBlocks && !block.Verify(false) {
return fmt.Errorf("block %s is invalid", block.Hash()) return fmt.Errorf("block %s is invalid", block.Hash())
} }
return bc.AddHeaders(block.Header()) headerLen := bc.headerListLen()
if int(block.Index) == headerLen {
err := bc.AddHeaders(block.Header())
if err != nil {
return err
} }
return nil }
return bc.storeBlock(block)
} }
// AddHeaders will process the given headers and add them to the // AddHeaders will process the given headers and add them to the
@ -205,7 +218,7 @@ func (bc *Blockchain) AddBlock(block *Block) error {
func (bc *Blockchain) AddHeaders(headers ...*Header) (err error) { func (bc *Blockchain) AddHeaders(headers ...*Header) (err error) {
var ( var (
start = time.Now() start = time.Now()
batch = bc.Batch() batch = bc.memStore.Batch()
) )
bc.headersOp <- func(headerList *HeaderHashList) { bc.headersOp <- func(headerList *HeaderHashList) {
@ -230,7 +243,7 @@ func (bc *Blockchain) AddHeaders(headers ...*Header) (err error) {
} }
if batch.Len() > 0 { if batch.Len() > 0 {
if err = bc.PutBatch(batch); err != nil { if err = bc.memStore.PutBatch(batch); err != nil {
return return
} }
log.WithFields(log.Fields{ log.WithFields(log.Fields{
@ -273,13 +286,13 @@ func (bc *Blockchain) processHeader(h *Header, batch storage.Batch, headerList *
return nil return nil
} }
// TODO: persistBlock needs some more love, its implemented as in the original // TODO: storeBlock needs some more love, its implemented as in the original
// project. This for the sake of development speed and understanding of what // project. This for the sake of development speed and understanding of what
// is happening here, quite allot as you can see :). If things are wired together // is happening here, quite allot as you can see :). If things are wired together
// and all tests are in place, we can make a more optimized and cleaner implementation. // and all tests are in place, we can make a more optimized and cleaner implementation.
func (bc *Blockchain) persistBlock(block *Block) error { func (bc *Blockchain) storeBlock(block *Block) error {
var ( var (
batch = bc.Batch() batch = bc.memStore.Batch()
unspentCoins = make(UnspentCoins) unspentCoins = make(UnspentCoins)
spentCoins = make(SpentCoins) spentCoins = make(SpentCoins)
accounts = make(Accounts) accounts = make(Accounts)
@ -301,7 +314,7 @@ func (bc *Blockchain) persistBlock(block *Block) error {
// Process TX outputs. // Process TX outputs.
for _, output := range tx.Outputs { for _, output := range tx.Outputs {
account, err := accounts.getAndUpdate(bc.Store, output.ScriptHash) account, err := accounts.getAndUpdate(bc.memStore, bc.Store, output.ScriptHash)
if err != nil { if err != nil {
return err return err
} }
@ -319,14 +332,14 @@ func (bc *Blockchain) persistBlock(block *Block) error {
return fmt.Errorf("could not find previous TX: %s", prevHash) return fmt.Errorf("could not find previous TX: %s", prevHash)
} }
for _, input := range inputs { for _, input := range inputs {
unspent, err := unspentCoins.getAndUpdate(bc.Store, input.PrevHash) unspent, err := unspentCoins.getAndUpdate(bc.memStore, bc.Store, input.PrevHash)
if err != nil { if err != nil {
return err return err
} }
unspent.states[input.PrevIndex] = CoinStateSpent unspent.states[input.PrevIndex] = CoinStateSpent
prevTXOutput := prevTX.Outputs[input.PrevIndex] prevTXOutput := prevTX.Outputs[input.PrevIndex]
account, err := accounts.getAndUpdate(bc.Store, prevTXOutput.ScriptHash) account, err := accounts.getAndUpdate(bc.memStore, bc.Store, prevTXOutput.ScriptHash)
if err != nil { if err != nil {
return err return err
} }
@ -388,7 +401,7 @@ func (bc *Blockchain) persistBlock(block *Block) error {
if err := assets.commit(batch); err != nil { if err := assets.commit(batch); err != nil {
return err return err
} }
if err := bc.PutBatch(batch); err != nil { if err := bc.memStore.PutBatch(batch); err != nil {
return err return err
} }
@ -396,63 +409,37 @@ func (bc *Blockchain) persistBlock(block *Block) error {
return nil return nil
} }
//Persist starts persist loop. // persist flushes current in-memory store contents to the persistent storage.
func (bc *Blockchain) Persist(ctx context.Context) (err error) { func (bc *Blockchain) persist(ctx context.Context) error {
var ( var (
start = time.Now() start = time.Now()
persisted = 0 persisted = 0
lenCache = bc.blockCache.Len() err error
) )
if lenCache == 0 { persisted, err = bc.memStore.Persist(bc.Store)
return nil if err != nil {
return err
} }
bHeight, err := storage.CurrentBlockHeight(bc.Store)
bc.headersOp <- func(headerList *HeaderHashList) { if err != nil {
for i := 0; i < lenCache; i++ { return err
if uint32(headerList.Len()) <= bc.BlockHeight() {
return
}
hash := headerList.Get(int(bc.BlockHeight() + 1))
if block, ok := bc.blockCache.GetBlock(hash); ok {
if err = bc.persistBlock(block); err != nil {
return
}
bc.blockCache.Delete(hash)
persisted++
} else {
// no next block in the cache, no reason to continue looping
break
}
}
}
select {
case <-ctx.Done():
return
case <-bc.headersOpDone:
//
} }
oldHeight := atomic.SwapUint32(&bc.persistedHeight, bHeight)
diff := bHeight - oldHeight
if persisted > 0 { if persisted > 0 {
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"persisted": persisted, "persistedBlocks": diff,
"persistedKeys": persisted,
"headerHeight": bc.HeaderHeight(), "headerHeight": bc.HeaderHeight(),
"blockHeight": bc.BlockHeight(), "blockHeight": bc.BlockHeight(),
"persistedHeight": bc.persistedHeight,
"took": time.Since(start), "took": time.Since(start),
}).Info("blockchain persist completed") }).Info("blockchain persist completed")
} else {
// So we have some blocks in cache but can't persist them?
// Either there are some stale blocks there or the other way
// around (which was seen in practice) --- there are some fresh
// blocks that we can't persist yet. Some of the latter can be useful
// or can be bogus (higher than the header height we expect at
// the moment). So try to reap oldies and strange newbies, if
// there are any.
bc.blockCache.ReapStrangeBlocks(bc.BlockHeight(), bc.HeaderHeight())
} }
return return nil
} }
func (bc *Blockchain) headerListLen() (n int) { func (bc *Blockchain) headerListLen() (n int) {
@ -468,9 +455,18 @@ func (bc *Blockchain) GetTransaction(hash util.Uint256) (*transaction.Transactio
if tx, ok := bc.memPool.TryGetValue(hash); ok { if tx, ok := bc.memPool.TryGetValue(hash); ok {
return tx, 0, nil // the height is not actually defined for memPool transaction. Not sure if zero is a good number in this case. return tx, 0, nil // the height is not actually defined for memPool transaction. Not sure if zero is a good number in this case.
} }
tx, height, err := getTransactionFromStore(bc.memStore, hash)
if err != nil {
tx, height, err = getTransactionFromStore(bc.Store, hash)
}
return tx, height, err
}
// getTransactionFromStore returns Transaction and its height by the given hash
// if it exists in the store.
func getTransactionFromStore(s storage.Store, hash util.Uint256) (*transaction.Transaction, uint32, error) {
key := storage.AppendPrefix(storage.DataTransaction, hash.BytesReverse()) key := storage.AppendPrefix(storage.DataTransaction, hash.BytesReverse())
b, err := bc.Get(key) b, err := s.Get(key)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
} }
@ -490,8 +486,23 @@ func (bc *Blockchain) GetTransaction(hash util.Uint256) (*transaction.Transactio
// GetBlock returns a Block by the given hash. // GetBlock returns a Block by the given hash.
func (bc *Blockchain) GetBlock(hash util.Uint256) (*Block, error) { func (bc *Blockchain) GetBlock(hash util.Uint256) (*Block, error) {
block, err := getBlockFromStore(bc.memStore, hash)
if err != nil {
block, err = getBlockFromStore(bc.Store, hash)
if err != nil {
return nil, err
}
}
if len(block.Transactions) == 0 {
return nil, fmt.Errorf("only header is available")
}
return block, nil
}
// getBlockFromStore returns Block by the given hash if it exists in the store.
func getBlockFromStore(s storage.Store, hash util.Uint256) (*Block, error) {
key := storage.AppendPrefix(storage.DataBlock, hash.BytesReverse()) key := storage.AppendPrefix(storage.DataBlock, hash.BytesReverse())
b, err := bc.Get(key) b, err := s.Get(key)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -499,20 +510,24 @@ func (bc *Blockchain) GetBlock(hash util.Uint256) (*Block, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
// TODO: persist TX first before we can handle this logic. return block, err
// if len(block.Transactions) == 0 {
// return nil, fmt.Errorf("block has no TX")
// }
return block, nil
} }
// GetHeader returns data block header identified with the given hash value. // GetHeader returns data block header identified with the given hash value.
func (bc *Blockchain) GetHeader(hash util.Uint256) (*Header, error) { func (bc *Blockchain) GetHeader(hash util.Uint256) (*Header, error) {
b, err := bc.Get(storage.AppendPrefix(storage.DataBlock, hash.BytesReverse())) header, err := getHeaderFromStore(bc.memStore, hash)
if err != nil {
header, err = getHeaderFromStore(bc.Store, hash)
if err != nil { if err != nil {
return nil, err return nil, err
} }
block, err := NewBlockFromTrimmedBytes(b) }
return header, err
}
// getHeaderFromStore returns Header by the given hash from the store.
func getHeaderFromStore(s storage.Store, hash util.Uint256) (*Header, error) {
block, err := getBlockFromStore(s, hash)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -522,12 +537,16 @@ func (bc *Blockchain) GetHeader(hash util.Uint256) (*Header, error) {
// HasTransaction return true if the blockchain contains he given // HasTransaction return true if the blockchain contains he given
// transaction hash. // transaction hash.
func (bc *Blockchain) HasTransaction(hash util.Uint256) bool { func (bc *Blockchain) HasTransaction(hash util.Uint256) bool {
if bc.memPool.ContainsKey(hash) { return bc.memPool.ContainsKey(hash) ||
return true checkTransactionInStore(bc.memStore, hash) ||
checkTransactionInStore(bc.Store, hash)
} }
// checkTransactionInStore returns true if the given store contains the given
// Transaction hash.
func checkTransactionInStore(s storage.Store, hash util.Uint256) bool {
key := storage.AppendPrefix(storage.DataTransaction, hash.BytesReverse()) key := storage.AppendPrefix(storage.DataTransaction, hash.BytesReverse())
if _, err := bc.Get(key); err == nil { if _, err := s.Get(key); err == nil {
return true return true
} }
return false return false
@ -582,31 +601,43 @@ func (bc *Blockchain) HeaderHeight() uint32 {
// GetAssetState returns asset state from its assetID // GetAssetState returns asset state from its assetID
func (bc *Blockchain) GetAssetState(assetID util.Uint256) *AssetState { func (bc *Blockchain) GetAssetState(assetID util.Uint256) *AssetState {
var as *AssetState as := getAssetStateFromStore(bc.memStore, assetID)
bc.Store.Seek(storage.STAsset.Bytes(), func(k, v []byte) { if as == nil {
var a AssetState as = getAssetStateFromStore(bc.Store, assetID)
r := io.NewBinReaderFromBuf(v)
a.DecodeBinary(r)
if r.Err == nil && a.ID == assetID {
as = &a
} }
})
return as return as
} }
// getAssetStateFromStore returns given asset state as recorded in the given
// store.
func getAssetStateFromStore(s storage.Store, assetID util.Uint256) *AssetState {
key := storage.AppendPrefix(storage.STAsset, assetID.Bytes())
asEncoded, err := s.Get(key)
if err != nil {
return nil
}
var a AssetState
r := io.NewBinReaderFromBuf(asEncoded)
a.DecodeBinary(r)
if r.Err != nil || a.ID != assetID {
return nil
}
return &a
}
// GetAccountState returns the account state from its script hash // GetAccountState returns the account state from its script hash
func (bc *Blockchain) GetAccountState(scriptHash util.Uint160) *AccountState { func (bc *Blockchain) GetAccountState(scriptHash util.Uint160) *AccountState {
var as *AccountState as, err := getAccountStateFromStore(bc.memStore, scriptHash)
bc.Store.Seek(storage.STAccount.Bytes(), func(k, v []byte) { if as == nil {
var a AccountState if err != storage.ErrKeyNotFound {
r := io.NewBinReaderFromBuf(v) log.Warnf("failed to get account state: %s", err)
a.DecodeBinary(r) }
if r.Err == nil && a.ScriptHash == scriptHash { as, err = getAccountStateFromStore(bc.Store, scriptHash)
as = &a if as == nil && err != storage.ErrKeyNotFound {
log.Warnf("failed to get account state: %s", err)
}
} }
})
return as return as
} }

View file

@ -13,9 +13,6 @@ import (
func TestAddHeaders(t *testing.T) { func TestAddHeaders(t *testing.T) {
bc := newTestChain(t) bc := newTestChain(t)
defer func() {
require.NoError(t, bc.Close())
}()
h1 := newBlock(1).Header() h1 := newBlock(1).Header()
h2 := newBlock(2).Header() h2 := newBlock(2).Header()
h3 := newBlock(3).Header() h3 := newBlock(3).Header()
@ -24,7 +21,6 @@ func TestAddHeaders(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
assert.Equal(t, 0, bc.blockCache.Len())
assert.Equal(t, h3.Index, bc.HeaderHeight()) assert.Equal(t, h3.Index, bc.HeaderHeight())
assert.Equal(t, uint32(0), bc.BlockHeight()) assert.Equal(t, uint32(0), bc.BlockHeight())
assert.Equal(t, h3.Hash(), bc.CurrentHeaderHash()) assert.Equal(t, h3.Hash(), bc.CurrentHeaderHash())
@ -41,9 +37,6 @@ func TestAddHeaders(t *testing.T) {
func TestAddBlock(t *testing.T) { func TestAddBlock(t *testing.T) {
bc := newTestChain(t) bc := newTestChain(t)
defer func() {
require.NoError(t, bc.Close())
}()
blocks := []*Block{ blocks := []*Block{
newBlock(1), newBlock(1),
newBlock(2), newBlock(2),
@ -57,15 +50,11 @@ func TestAddBlock(t *testing.T) {
} }
lastBlock := blocks[len(blocks)-1] lastBlock := blocks[len(blocks)-1]
assert.Equal(t, 3, bc.blockCache.Len())
assert.Equal(t, lastBlock.Index, bc.HeaderHeight()) assert.Equal(t, lastBlock.Index, bc.HeaderHeight())
assert.Equal(t, lastBlock.Hash(), bc.CurrentHeaderHash()) assert.Equal(t, lastBlock.Hash(), bc.CurrentHeaderHash())
t.Log(bc.blockCache) // This one tests persisting blocks, so it does need to persist()
require.NoError(t, bc.persist(context.Background()))
if err := bc.Persist(context.Background()); err != nil {
t.Fatal(err)
}
for _, block := range blocks { for _, block := range blocks {
key := storage.AppendPrefix(storage.DataBlock, block.Hash().BytesReverse()) key := storage.AppendPrefix(storage.DataBlock, block.Hash().BytesReverse())
@ -76,33 +65,30 @@ func TestAddBlock(t *testing.T) {
assert.Equal(t, lastBlock.Index, bc.BlockHeight()) assert.Equal(t, lastBlock.Index, bc.BlockHeight())
assert.Equal(t, lastBlock.Hash(), bc.CurrentHeaderHash()) assert.Equal(t, lastBlock.Hash(), bc.CurrentHeaderHash())
assert.Equal(t, 0, bc.blockCache.Len())
} }
func TestGetHeader(t *testing.T) { func TestGetHeader(t *testing.T) {
bc := newTestChain(t) bc := newTestChain(t)
defer func() {
require.NoError(t, bc.Close())
}()
block := newBlock(1) block := newBlock(1)
err := bc.AddBlock(block) err := bc.AddBlock(block)
assert.Nil(t, err) assert.Nil(t, err)
// Test unpersisted and persisted access
for i := 0; i < 2; i++ {
hash := block.Hash() hash := block.Hash()
header, err := bc.GetHeader(hash) header, err := bc.GetHeader(hash)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, block.Header(), header) assert.Equal(t, block.Header(), header)
block = newBlock(2) b2 := newBlock(2)
_, err = bc.GetHeader(block.Hash()) _, err = bc.GetHeader(b2.Hash())
assert.Error(t, err) assert.Error(t, err)
assert.NoError(t, bc.persist(context.Background()))
}
} }
func TestGetBlock(t *testing.T) { func TestGetBlock(t *testing.T) {
bc := newTestChain(t) bc := newTestChain(t)
defer func() {
require.NoError(t, bc.Close())
}()
blocks := makeBlocks(100) blocks := makeBlocks(100)
for i := 0; i < len(blocks); i++ { for i := 0; i < len(blocks); i++ {
@ -111,6 +97,8 @@ func TestGetBlock(t *testing.T) {
} }
} }
// Test unpersisted and persisted access
for j := 0; j < 2; j++ {
for i := 0; i < len(blocks); i++ { for i := 0; i < len(blocks); i++ {
block, err := bc.GetBlock(blocks[i].Hash()) block, err := bc.GetBlock(blocks[i].Hash())
if err != nil { if err != nil {
@ -119,13 +107,12 @@ func TestGetBlock(t *testing.T) {
assert.Equal(t, blocks[i].Index, block.Index) assert.Equal(t, blocks[i].Index, block.Index)
assert.Equal(t, blocks[i].Hash(), block.Hash()) assert.Equal(t, blocks[i].Hash(), block.Hash())
} }
assert.NoError(t, bc.persist(context.Background()))
}
} }
func TestHasBlock(t *testing.T) { func TestHasBlock(t *testing.T) {
bc := newTestChain(t) bc := newTestChain(t)
defer func() {
require.NoError(t, bc.Close())
}()
blocks := makeBlocks(50) blocks := makeBlocks(50)
for i := 0; i < len(blocks); i++ { for i := 0; i < len(blocks); i++ {
@ -133,30 +120,30 @@ func TestHasBlock(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
} }
assert.Nil(t, bc.Persist(context.Background()))
// Test unpersisted and persisted access
for j := 0; j < 2; j++ {
for i := 0; i < len(blocks); i++ { for i := 0; i < len(blocks); i++ {
assert.True(t, bc.HasBlock(blocks[i].Hash())) assert.True(t, bc.HasBlock(blocks[i].Hash()))
} }
newBlock := newBlock(51) newBlock := newBlock(51)
assert.False(t, bc.HasBlock(newBlock.Hash())) assert.False(t, bc.HasBlock(newBlock.Hash()))
assert.NoError(t, bc.persist(context.Background()))
}
} }
func TestGetTransaction(t *testing.T) { func TestGetTransaction(t *testing.T) {
b1 := getDecodedBlock(t, 1)
block := getDecodedBlock(t, 2) block := getDecodedBlock(t, 2)
bc := newTestChain(t) bc := newTestChain(t)
defer func() {
require.NoError(t, bc.Close())
}()
assert.Nil(t, bc.AddBlock(b1))
assert.Nil(t, bc.AddBlock(block)) assert.Nil(t, bc.AddBlock(block))
assert.Nil(t, bc.persistBlock(block))
// Test unpersisted and persisted access
for j := 0; j < 2; j++ {
tx, height, err := bc.GetTransaction(block.Transactions[0].Hash()) tx, height, err := bc.GetTransaction(block.Transactions[0].Hash())
if err != nil { require.Nil(t, err)
t.Fatal(err)
}
assert.Equal(t, block.Index, height) assert.Equal(t, block.Index, height)
assert.Equal(t, block.Transactions[0], tx) assert.Equal(t, block.Transactions[0], tx)
assert.Equal(t, 10, io.GetVarSize(tx)) assert.Equal(t, 10, io.GetVarSize(tx))
@ -164,6 +151,8 @@ func TestGetTransaction(t *testing.T) {
assert.Equal(t, 1, io.GetVarSize(tx.Inputs)) assert.Equal(t, 1, io.GetVarSize(tx.Inputs))
assert.Equal(t, 1, io.GetVarSize(tx.Outputs)) assert.Equal(t, 1, io.GetVarSize(tx.Outputs))
assert.Equal(t, 1, io.GetVarSize(tx.Scripts)) assert.Equal(t, 1, io.GetVarSize(tx.Scripts))
assert.NoError(t, bc.persist(context.Background()))
}
} }
func newTestChain(t *testing.T) *Blockchain { func newTestChain(t *testing.T) *Blockchain {

View file

@ -1,87 +0,0 @@
package core
import (
"sync"
"github.com/CityOfZion/neo-go/pkg/util"
)
// Cache is data structure with fixed type key of Uint256, but has a
// generic value. Used for block, tx and header cache types.
type Cache struct {
lock sync.RWMutex
m map[util.Uint256]interface{}
}
// NewCache returns a ready to use Cache object.
func NewCache() *Cache {
return &Cache{
m: make(map[util.Uint256]interface{}),
}
}
// GetBlock will return a Block type from the cache.
func (c *Cache) GetBlock(h util.Uint256) (block *Block, ok bool) {
c.lock.RLock()
defer c.lock.RUnlock()
return c.getBlock(h)
}
func (c *Cache) getBlock(h util.Uint256) (block *Block, ok bool) {
if v, b := c.m[h]; b {
block, ok = v.(*Block)
return
}
return
}
// Add adds the given hash along with its value to the cache.
func (c *Cache) Add(h util.Uint256, v interface{}) {
c.lock.Lock()
defer c.lock.Unlock()
c.add(h, v)
}
func (c *Cache) add(h util.Uint256, v interface{}) {
c.m[h] = v
}
func (c *Cache) has(h util.Uint256) bool {
_, ok := c.m[h]
return ok
}
// Has returns whether the cache contains the given hash.
func (c *Cache) Has(h util.Uint256) bool {
c.lock.Lock()
defer c.lock.Unlock()
return c.has(h)
}
// Len return the number of items present in the cache.
func (c *Cache) Len() int {
c.lock.RLock()
defer c.lock.RUnlock()
return len(c.m)
}
// Delete removes the item out of the cache.
func (c *Cache) Delete(h util.Uint256) {
c.lock.Lock()
defer c.lock.Unlock()
delete(c.m, h)
}
// ReapStrangeBlocks drops blocks from cache that don't fit into the
// blkHeight-headHeight interval. Cache should only contain blocks that we
// expect to get and store.
func (c *Cache) ReapStrangeBlocks(blkHeight, headHeight uint32) {
c.lock.Lock()
defer c.lock.Unlock()
for i, b := range c.m {
block, ok := b.(*Block)
if ok && (block.Index < blkHeight || block.Index > headHeight) {
delete(c.m, i)
}
}
}

View file

@ -24,25 +24,6 @@ type BoltDBStore struct {
db *bbolt.DB db *bbolt.DB
} }
// BoltDBBatch simple batch implementation to satisfy the Store interface.
type BoltDBBatch struct {
mem map[*[]byte][]byte
}
// Len implements the Batch interface.
func (b *BoltDBBatch) Len() int {
return len(b.mem)
}
// Put implements the Batch interface.
func (b *BoltDBBatch) Put(k, v []byte) {
vcopy := make([]byte, len(v))
copy(vcopy, v)
kcopy := make([]byte, len(k))
copy(kcopy, k)
b.mem[&kcopy] = vcopy
}
// NewBoltDBStore returns a new ready to use BoltDB storage with created bucket. // NewBoltDBStore returns a new ready to use BoltDB storage with created bucket.
func NewBoltDBStore(cfg BoltDBOptions) (*BoltDBStore, error) { func NewBoltDBStore(cfg BoltDBOptions) (*BoltDBStore, error) {
var opts *bbolt.Options // should be exposed via BoltDBOptions if anything needed var opts *bbolt.Options // should be exposed via BoltDBOptions if anything needed
@ -94,7 +75,7 @@ func (s *BoltDBStore) Get(key []byte) (val []byte, err error) {
func (s *BoltDBStore) PutBatch(batch Batch) error { func (s *BoltDBStore) PutBatch(batch Batch) error {
return s.db.Batch(func(tx *bbolt.Tx) error { return s.db.Batch(func(tx *bbolt.Tx) error {
b := tx.Bucket(Bucket) b := tx.Bucket(Bucket)
for k, v := range batch.(*BoltDBBatch).mem { for k, v := range batch.(*MemoryBatch).m {
err := b.Put(*k, v) err := b.Put(*k, v)
if err != nil { if err != nil {
return err return err
@ -122,9 +103,7 @@ func (s *BoltDBStore) Seek(key []byte, f func(k, v []byte)) {
// Batch implements the Batch interface and returns a boltdb // Batch implements the Batch interface and returns a boltdb
// compatible Batch. // compatible Batch.
func (s *BoltDBStore) Batch() Batch { func (s *BoltDBStore) Batch() Batch {
return &BoltDBBatch{ return newMemoryBatch()
mem: make(map[*[]byte][]byte),
}
} }
// Close releases all db resources. // Close releases all db resources.

View file

@ -12,14 +12,14 @@ import (
func TestBoltDBBatch(t *testing.T) { func TestBoltDBBatch(t *testing.T) {
boltDB := BoltDBStore{} boltDB := BoltDBStore{}
want := &BoltDBBatch{mem: map[*[]byte][]byte{}} want := &MemoryBatch{m: map[*[]byte][]byte{}}
if got := boltDB.Batch(); !reflect.DeepEqual(got, want) { if got := boltDB.Batch(); !reflect.DeepEqual(got, want) {
t.Errorf("BoltDB Batch() = %v, want %v", got, want) t.Errorf("BoltDB Batch() = %v, want %v", got, want)
} }
} }
func TestBoltDBBatch_Len(t *testing.T) { func TestBoltDBBatch_Len(t *testing.T) {
batch := &BoltDBBatch{mem: map[*[]byte][]byte{}} batch := &MemoryBatch{m: map[*[]byte][]byte{}}
want := len(map[*[]byte][]byte{}) want := len(map[*[]byte][]byte{})
assert.Equal(t, want, batch.Len()) assert.Equal(t, want, batch.Len())
} }

View file

@ -41,7 +41,11 @@ func (s *LevelDBStore) Put(key, value []byte) error {
// Get implements the Store interface. // Get implements the Store interface.
func (s *LevelDBStore) Get(key []byte) ([]byte, error) { func (s *LevelDBStore) Get(key []byte) ([]byte, error) {
return s.db.Get(key, nil) value, err := s.db.Get(key, nil)
if err == leveldb.ErrNotFound {
err = ErrKeyNotFound
}
return value, err
} }
// PutBatch implements the Store interface. // PutBatch implements the Store interface.

View file

@ -9,7 +9,7 @@ import (
// MemoryStore is an in-memory implementation of a Store, mainly // MemoryStore is an in-memory implementation of a Store, mainly
// used for testing. Do not use MemoryStore in production. // used for testing. Do not use MemoryStore in production.
type MemoryStore struct { type MemoryStore struct {
*sync.RWMutex mut sync.RWMutex
mem map[string][]byte mem map[string][]byte
} }
@ -20,8 +20,11 @@ type MemoryBatch struct {
// Put implements the Batch interface. // Put implements the Batch interface.
func (b *MemoryBatch) Put(k, v []byte) { func (b *MemoryBatch) Put(k, v []byte) {
key := &k vcopy := make([]byte, len(v))
b.m[key] = v copy(vcopy, v)
kcopy := make([]byte, len(k))
copy(kcopy, k)
b.m[&kcopy] = vcopy
} }
// Len implements the Batch interface. // Len implements the Batch interface.
@ -32,36 +35,33 @@ func (b *MemoryBatch) Len() int {
// NewMemoryStore creates a new MemoryStore object. // NewMemoryStore creates a new MemoryStore object.
func NewMemoryStore() *MemoryStore { func NewMemoryStore() *MemoryStore {
return &MemoryStore{ return &MemoryStore{
RWMutex: new(sync.RWMutex),
mem: make(map[string][]byte), mem: make(map[string][]byte),
} }
} }
// Get implements the Store interface. // Get implements the Store interface.
func (s *MemoryStore) Get(key []byte) ([]byte, error) { func (s *MemoryStore) Get(key []byte) ([]byte, error) {
s.RLock() s.mut.RLock()
defer s.RUnlock() defer s.mut.RUnlock()
if val, ok := s.mem[makeKey(key)]; ok { if val, ok := s.mem[makeKey(key)]; ok {
return val, nil return val, nil
} }
return nil, ErrKeyNotFound return nil, ErrKeyNotFound
} }
// Put implements the Store interface. // Put implements the Store interface. Never returns an error.
func (s *MemoryStore) Put(key, value []byte) error { func (s *MemoryStore) Put(key, value []byte) error {
s.Lock() s.mut.Lock()
s.mem[makeKey(key)] = value s.mem[makeKey(key)] = value
s.Unlock() s.mut.Unlock()
return nil return nil
} }
// PutBatch implements the Store interface. // PutBatch implements the Store interface. Never returns an error.
func (s *MemoryStore) PutBatch(batch Batch) error { func (s *MemoryStore) PutBatch(batch Batch) error {
b := batch.(*MemoryBatch) b := batch.(*MemoryBatch)
for k, v := range b.m { for k, v := range b.m {
if err := s.Put(*k, v); err != nil { _ = s.Put(*k, v)
return err
}
} }
return nil return nil
} }
@ -78,16 +78,44 @@ func (s *MemoryStore) Seek(key []byte, f func(k, v []byte)) {
// Batch implements the Batch interface and returns a compatible Batch. // Batch implements the Batch interface and returns a compatible Batch.
func (s *MemoryStore) Batch() Batch { func (s *MemoryStore) Batch() Batch {
return newMemoryBatch()
}
// newMemoryBatch returns new memory batch.
func newMemoryBatch() *MemoryBatch {
return &MemoryBatch{ return &MemoryBatch{
m: make(map[*[]byte][]byte), m: make(map[*[]byte][]byte),
} }
} }
// Close implements Store interface and clears up memory. // Persist flushes all the MemoryStore contents into the (supposedly) persistent
// store provided via parameter.
func (s *MemoryStore) Persist(ps Store) (int, error) {
s.mut.Lock()
defer s.mut.Unlock()
batch := ps.Batch()
keys := 0
for k, v := range s.mem {
kb, _ := hex.DecodeString(k)
batch.Put(kb, v)
keys++
}
var err error
if keys != 0 {
err = ps.PutBatch(batch)
}
if err == nil {
s.mem = make(map[string][]byte)
}
return keys, err
}
// Close implements Store interface and clears up memory. Never returns an
// error.
func (s *MemoryStore) Close() error { func (s *MemoryStore) Close() error {
s.Lock() s.mut.Lock()
s.mem = nil s.mem = nil
s.Unlock() s.mut.Unlock()
return nil return nil
} }

View file

@ -75,3 +75,54 @@ func TestMemoryStore_Seek(t *testing.T) {
assert.Equal(t, value, v) assert.Equal(t, value, v)
}) })
} }
func TestMemoryStorePersist(t *testing.T) {
// temporary Store
ts := NewMemoryStore()
// persistent Store
ps := NewMemoryStore()
// persisting nothing should do nothing
c, err := ts.Persist(ps)
assert.Equal(t, nil, err)
assert.Equal(t, 0, c)
// persisting one key should result in one key in ps and nothing in ts
assert.NoError(t, ts.Put([]byte("key"), []byte("value")))
c, err = ts.Persist(ps)
assert.Equal(t, nil, err)
assert.Equal(t, 1, c)
v, err := ps.Get([]byte("key"))
assert.Equal(t, nil, err)
assert.Equal(t, []byte("value"), v)
v, err = ts.Get([]byte("key"))
assert.Equal(t, ErrKeyNotFound, err)
assert.Equal(t, []byte(nil), v)
// now we overwrite the previous `key` contents and also add `key2`,
assert.NoError(t, ts.Put([]byte("key"), []byte("newvalue")))
assert.NoError(t, ts.Put([]byte("key2"), []byte("value2")))
// this is to check that now key is written into the ps before we do
// persist
v, err = ps.Get([]byte("key2"))
assert.Equal(t, ErrKeyNotFound, err)
assert.Equal(t, []byte(nil), v)
// two keys should be persisted (one overwritten and one new) and
// available in the ps
c, err = ts.Persist(ps)
assert.Equal(t, nil, err)
assert.Equal(t, 2, c)
v, err = ts.Get([]byte("key"))
assert.Equal(t, ErrKeyNotFound, err)
assert.Equal(t, []byte(nil), v)
v, err = ts.Get([]byte("key2"))
assert.Equal(t, ErrKeyNotFound, err)
assert.Equal(t, []byte(nil), v)
v, err = ps.Get([]byte("key"))
assert.Equal(t, nil, err)
assert.Equal(t, []byte("newvalue"), v)
v, err = ps.Get([]byte("key2"))
assert.Equal(t, nil, err)
assert.Equal(t, []byte("value2"), v)
// we've persisted some values, make sure successive persist is a no-op
c, err = ts.Persist(ps)
assert.Equal(t, nil, err)
assert.Equal(t, 0, c)
}

View file

@ -18,28 +18,6 @@ type RedisStore struct {
client *redis.Client client *redis.Client
} }
// RedisBatch simple batch implementation to satisfy the Store interface.
type RedisBatch struct {
mem map[string]string
}
// Len implements the Batch interface.
func (b *RedisBatch) Len() int {
return len(b.mem)
}
// Put implements the Batch interface.
func (b *RedisBatch) Put(k, v []byte) {
b.mem[string(k)] = string(v)
}
// NewRedisBatch returns a new ready to use RedisBatch.
func NewRedisBatch() *RedisBatch {
return &RedisBatch{
mem: make(map[string]string),
}
}
// NewRedisStore returns an new initialized - ready to use RedisStore object. // NewRedisStore returns an new initialized - ready to use RedisStore object.
func NewRedisStore(cfg RedisDBOptions) (*RedisStore, error) { func NewRedisStore(cfg RedisDBOptions) (*RedisStore, error) {
c := redis.NewClient(&redis.Options{ c := redis.NewClient(&redis.Options{
@ -55,13 +33,16 @@ func NewRedisStore(cfg RedisDBOptions) (*RedisStore, error) {
// Batch implements the Store interface. // Batch implements the Store interface.
func (s *RedisStore) Batch() Batch { func (s *RedisStore) Batch() Batch {
return NewRedisBatch() return newMemoryBatch()
} }
// Get implements the Store interface. // Get implements the Store interface.
func (s *RedisStore) Get(k []byte) ([]byte, error) { func (s *RedisStore) Get(k []byte) ([]byte, error) {
val, err := s.client.Get(string(k)).Result() val, err := s.client.Get(string(k)).Result()
if err != nil { if err != nil {
if err == redis.Nil {
err = ErrKeyNotFound
}
return nil, err return nil, err
} }
return []byte(val), nil return []byte(val), nil
@ -76,8 +57,8 @@ func (s *RedisStore) Put(k, v []byte) error {
// PutBatch implements the Store interface. // PutBatch implements the Store interface.
func (s *RedisStore) PutBatch(b Batch) error { func (s *RedisStore) PutBatch(b Batch) error {
pipe := s.client.Pipeline() pipe := s.client.Pipeline()
for k, v := range b.(*RedisBatch).mem { for k, v := range b.(*MemoryBatch).m {
pipe.Set(k, v, 0) pipe.Set(string(*k), v, 0)
} }
_, err := pipe.Exec() _, err := pipe.Exec()
return err return err

View file

@ -1,7 +1,6 @@
package storage package storage
import ( import (
"reflect"
"testing" "testing"
"github.com/alicebob/miniredis" "github.com/alicebob/miniredis"
@ -9,13 +8,6 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestNewRedisBatch(t *testing.T) {
want := &RedisBatch{mem: map[string]string{}}
if got := NewRedisBatch(); !reflect.DeepEqual(got, want) {
t.Errorf("NewRedisBatch() = %v, want %v", got, want)
}
}
func TestNewRedisStore(t *testing.T) { func TestNewRedisStore(t *testing.T) {
redisMock, redisStore := prepareRedisMock(t) redisMock, redisStore := prepareRedisMock(t)
key := []byte("testKey") key := []byte("testKey")
@ -33,50 +25,10 @@ func TestNewRedisStore(t *testing.T) {
func TestRedisBatch_Len(t *testing.T) { func TestRedisBatch_Len(t *testing.T) {
want := len(map[string]string{}) want := len(map[string]string{})
b := &RedisBatch{ b := &MemoryBatch{
mem: map[string]string{}, m: map[*[]byte][]byte{},
} }
assert.Equal(t, len(b.mem), want) assert.Equal(t, len(b.m), want)
}
func TestRedisBatch_Put(t *testing.T) {
type args struct {
k []byte
v []byte
}
tests := []struct {
name string
args args
want *RedisBatch
}{
{"TestRedisBatch_Put_Strings",
args{
k: []byte("foo"),
v: []byte("bar"),
},
&RedisBatch{mem: map[string]string{"foo": "bar"}},
},
{"TestRedisBatch_Put_Numbers",
args{
k: []byte("123"),
v: []byte("456"),
},
&RedisBatch{mem: map[string]string{"123": "456"}},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
actual := &RedisBatch{mem: map[string]string{}}
actual.Put(tt.args.k, tt.args.v)
assert.Equal(t, tt.want, actual)
})
}
}
func TestRedisStore_Batch(t *testing.T) {
want := &RedisBatch{mem: map[string]string{}}
actual := NewRedisBatch()
assert.Equal(t, want, actual)
} }
func TestRedisStore_GetAndPut(t *testing.T) { func TestRedisStore_GetAndPut(t *testing.T) {
@ -130,7 +82,7 @@ func TestRedisStore_GetAndPut(t *testing.T) {
} }
func TestRedisStore_PutBatch(t *testing.T) { func TestRedisStore_PutBatch(t *testing.T) {
batch := &RedisBatch{mem: map[string]string{"foo1": "bar1"}} batch := &MemoryBatch{m: map[*[]byte][]byte{&[]byte{'f', 'o', 'o', '1'}: []byte("bar1")}}
mock, redisStore := prepareRedisMock(t) mock, redisStore := prepareRedisMock(t)
err := redisStore.PutBatch(batch) err := redisStore.PutBatch(batch)
assert.Nil(t, err, "Error while PutBatch") assert.Nil(t, err, "Error while PutBatch")

View file

@ -13,11 +13,36 @@ import (
// coin state. // coin state.
type UnspentCoins map[util.Uint256]*UnspentCoinState type UnspentCoins map[util.Uint256]*UnspentCoinState
func (u UnspentCoins) getAndUpdate(s storage.Store, hash util.Uint256) (*UnspentCoinState, error) { // getAndUpdate retreives UnspentCoinState from temporary or persistent Store
// and return it. If it's not present in both stores, returns a new
// UnspentCoinState.
func (u UnspentCoins) getAndUpdate(ts storage.Store, ps storage.Store, hash util.Uint256) (*UnspentCoinState, error) {
if unspent, ok := u[hash]; ok { if unspent, ok := u[hash]; ok {
return unspent, nil return unspent, nil
} }
unspent, err := getUnspentCoinStateFromStore(ts, hash)
if err != nil {
if err != storage.ErrKeyNotFound {
return nil, err
}
unspent, err = getUnspentCoinStateFromStore(ps, hash)
if err != nil {
if err != storage.ErrKeyNotFound {
return nil, err
}
unspent = &UnspentCoinState{
states: []CoinState{},
}
}
}
u[hash] = unspent
return unspent, nil
}
// getUnspentCoinStateFromStore retrieves UnspentCoinState from the given store
func getUnspentCoinStateFromStore(s storage.Store, hash util.Uint256) (*UnspentCoinState, error) {
unspent := &UnspentCoinState{} unspent := &UnspentCoinState{}
key := storage.AppendPrefix(storage.STCoin, hash.BytesReverse()) key := storage.AppendPrefix(storage.STCoin, hash.BytesReverse())
if b, err := s.Get(key); err == nil { if b, err := s.Get(key); err == nil {
@ -27,12 +52,8 @@ func (u UnspentCoins) getAndUpdate(s storage.Store, hash util.Uint256) (*Unspent
return nil, fmt.Errorf("failed to decode (UnspentCoinState): %s", r.Err) return nil, fmt.Errorf("failed to decode (UnspentCoinState): %s", r.Err)
} }
} else { } else {
unspent = &UnspentCoinState{ return nil, err
states: []CoinState{},
} }
}
u[hash] = unspent
return unspent, nil return unspent, nil
} }

77
pkg/network/blockqueue.go Normal file
View file

@ -0,0 +1,77 @@
package network
import (
"github.com/CityOfZion/neo-go/pkg/core"
"github.com/Workiva/go-datastructures/queue"
log "github.com/sirupsen/logrus"
)
type blockQueue struct {
queue *queue.PriorityQueue
checkBlocks chan struct{}
chain core.Blockchainer
}
func newBlockQueue(capacity int, bc core.Blockchainer) *blockQueue {
return &blockQueue{
queue: queue.NewPriorityQueue(capacity, false),
checkBlocks: make(chan struct{}, 1),
chain: bc,
}
}
func (bq *blockQueue) run() {
for {
_, ok := <-bq.checkBlocks
if !ok {
break
}
for {
item := bq.queue.Peek()
if item == nil {
break
}
minblock := item.(*core.Block)
if minblock.Index <= bq.chain.BlockHeight()+1 {
_, _ = bq.queue.Get(1)
if minblock.Index == bq.chain.BlockHeight()+1 {
err := bq.chain.AddBlock(minblock)
if err != nil {
log.WithFields(log.Fields{
"error": err.Error(),
"blockHeight": bq.chain.BlockHeight(),
"nextIndex": minblock.Index,
}).Warn("blockQueue: failed adding block into the blockchain")
}
}
} else {
break
}
}
}
}
func (bq *blockQueue) putBlock(block *core.Block) error {
if bq.chain.BlockHeight() >= block.Index {
// can easily happen when fetching the same blocks from
// different peers, thus not considered as error
return nil
}
err := bq.queue.Put(block)
select {
case bq.checkBlocks <- struct{}{}:
// ok, signalled to goroutine processing queue
default:
// it's already busy processing blocks
}
return err
}
func (bq *blockQueue) discard() {
close(bq.checkBlocks)
bq.queue.Dispose()
}
func (bq *blockQueue) length() int {
return bq.queue.Len()
}

View file

@ -0,0 +1,71 @@
package network
import (
"testing"
"time"
"github.com/CityOfZion/neo-go/pkg/core"
"github.com/stretchr/testify/assert"
)
func TestBlockQueue(t *testing.T) {
chain := &testChain{}
// notice, it's not yet running
bq := newBlockQueue(0, chain)
blocks := make([]*core.Block, 11)
for i := 1; i < 11; i++ {
blocks[i] = &core.Block{BlockBase: core.BlockBase{Index: uint32(i)}}
}
// not the ones expected currently
for i := 3; i < 5; i++ {
assert.NoError(t, bq.putBlock(blocks[i]))
}
// nothing should be put into the blockchain
assert.Equal(t, uint32(0), chain.BlockHeight())
assert.Equal(t, 2, bq.length())
// now added expected ones (with duplicates)
for i := 1; i < 5; i++ {
assert.NoError(t, bq.putBlock(blocks[i]))
}
// but they're still not put into the blockchain, because bq isn't running
assert.Equal(t, uint32(0), chain.BlockHeight())
assert.Equal(t, 4, bq.length())
go bq.run()
// run() is asynchronous, so we need some kind of timeout anyway and this is the simplest one
for i := 0; i < 5; i++ {
if chain.BlockHeight() != 4 {
time.Sleep(time.Second)
}
}
assert.Equal(t, 0, bq.length())
assert.Equal(t, uint32(4), chain.BlockHeight())
// put some old blocks
for i := 1; i < 5; i++ {
assert.NoError(t, bq.putBlock(blocks[i]))
}
assert.Equal(t, 0, bq.length())
assert.Equal(t, uint32(4), chain.BlockHeight())
// unexpected blocks with run() active
assert.NoError(t, bq.putBlock(blocks[8]))
assert.Equal(t, 1, bq.length())
assert.Equal(t, uint32(4), chain.BlockHeight())
assert.NoError(t, bq.putBlock(blocks[7]))
assert.Equal(t, 2, bq.length())
assert.Equal(t, uint32(4), chain.BlockHeight())
// sparse put
assert.NoError(t, bq.putBlock(blocks[10]))
assert.Equal(t, 3, bq.length())
assert.Equal(t, uint32(4), chain.BlockHeight())
assert.NoError(t, bq.putBlock(blocks[6]))
assert.NoError(t, bq.putBlock(blocks[5]))
// run() is asynchronous, so we need some kind of timeout anyway and this is the simplest one
for i := 0; i < 5; i++ {
if chain.BlockHeight() != 8 {
time.Sleep(time.Second)
}
}
assert.Equal(t, 1, bq.length())
assert.Equal(t, uint32(8), chain.BlockHeight())
bq.discard()
assert.Equal(t, 0, bq.length())
}

View file

@ -3,6 +3,7 @@ package network
import ( import (
"math/rand" "math/rand"
"net" "net"
"sync/atomic"
"testing" "testing"
"time" "time"
@ -13,7 +14,9 @@ import (
"github.com/CityOfZion/neo-go/pkg/util" "github.com/CityOfZion/neo-go/pkg/util"
) )
type testChain struct{} type testChain struct {
blockheight uint32
}
func (chain testChain) GetConfig() config.ProtocolConfiguration { func (chain testChain) GetConfig() config.ProtocolConfiguration {
panic("TODO") panic("TODO")
@ -38,11 +41,14 @@ func (chain testChain) NetworkFee(t *transaction.Transaction) util.Fixed8 {
func (chain testChain) AddHeaders(...*core.Header) error { func (chain testChain) AddHeaders(...*core.Header) error {
panic("TODO") panic("TODO")
} }
func (chain testChain) AddBlock(*core.Block) error { func (chain *testChain) AddBlock(block *core.Block) error {
panic("TODO") if block.Index == chain.blockheight+1 {
atomic.StoreUint32(&chain.blockheight, block.Index)
} }
func (chain testChain) BlockHeight() uint32 { return nil
return 0 }
func (chain *testChain) BlockHeight() uint32 {
return atomic.LoadUint32(&chain.blockheight)
} }
func (chain testChain) HeaderHeight() uint32 { func (chain testChain) HeaderHeight() uint32 {
return 0 return 0
@ -168,7 +174,7 @@ func (p *localPeer) Handshaked() bool {
func newTestServer() *Server { func newTestServer() *Server {
return &Server{ return &Server{
ServerConfig: ServerConfig{}, ServerConfig: ServerConfig{},
chain: testChain{}, chain: &testChain{},
transport: localTransport{}, transport: localTransport{},
discovery: testDiscovery{}, discovery: testDiscovery{},
id: rand.Uint32(), id: rand.Uint32(),

View file

@ -45,6 +45,7 @@ type (
transport Transporter transport Transporter
discovery Discoverer discovery Discoverer
chain core.Blockchainer chain core.Blockchainer
bQueue *blockQueue
lock sync.RWMutex lock sync.RWMutex
peers map[Peer]bool peers map[Peer]bool
@ -66,6 +67,7 @@ func NewServer(config ServerConfig, chain core.Blockchainer) *Server {
s := &Server{ s := &Server{
ServerConfig: config, ServerConfig: config,
chain: chain, chain: chain,
bQueue: newBlockQueue(maxBlockBatch, chain),
id: rand.Uint32(), id: rand.Uint32(),
quit: make(chan struct{}), quit: make(chan struct{}),
addrReq: make(chan *Message, minPeers), addrReq: make(chan *Message, minPeers),
@ -97,6 +99,7 @@ func (s *Server) Start(errChan chan error) {
s.discovery.BackFill(s.Seeds...) s.discovery.BackFill(s.Seeds...)
go s.bQueue.run()
go s.transport.Accept() go s.transport.Accept()
s.run() s.run()
} }
@ -106,6 +109,7 @@ func (s *Server) Shutdown() {
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"peers": s.PeerCount(), "peers": s.PeerCount(),
}).Info("shutting down server") }).Info("shutting down server")
s.bQueue.discard()
close(s.quit) close(s.quit)
} }
@ -273,10 +277,7 @@ func (s *Server) handleHeadersCmd(p Peer, headers *payload.Headers) {
// handleBlockCmd processes the received block received from its peer. // handleBlockCmd processes the received block received from its peer.
func (s *Server) handleBlockCmd(p Peer, block *core.Block) error { func (s *Server) handleBlockCmd(p Peer, block *core.Block) error {
if !s.chain.HasBlock(block.Hash()) { return s.bQueue.putBlock(block)
return s.chain.AddBlock(block)
}
return nil
} }
// handleInvCmd will process the received inventory. // handleInvCmd will process the received inventory.
@ -329,7 +330,7 @@ func (s *Server) requestBlocks(p Peer) error {
hashStart = s.chain.BlockHeight() + 1 hashStart = s.chain.BlockHeight() + 1
headerHeight = s.chain.HeaderHeight() headerHeight = s.chain.HeaderHeight()
) )
for hashStart < headerHeight && len(hashes) < maxBlockBatch { for hashStart <= headerHeight && len(hashes) < maxBlockBatch {
hash := s.chain.GetHeaderHash(int(hashStart)) hash := s.chain.GetHeaderHash(int(hashStart))
hashes = append(hashes, hash) hashes = append(hashes, hash)
hashStart++ hashStart++

View file

@ -167,7 +167,6 @@ func initBlocks(t *testing.T, chain *core.Blockchain) {
for i := 0; i < len(blocks); i++ { for i := 0; i < len(blocks); i++ {
require.NoError(t, chain.AddBlock(blocks[i])) require.NoError(t, chain.AddBlock(blocks[i]))
} }
require.NoError(t, chain.Persist(context.Background()))
} }
func makeBlocks(n int) []*core.Block { func makeBlocks(n int) []*core.Block {

View file

@ -13,7 +13,6 @@ import (
"time" "time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func TestRPC(t *testing.T) { func TestRPC(t *testing.T) {
@ -22,9 +21,6 @@ func TestRPC(t *testing.T) {
chain, handler := initServerWithInMemoryChain(ctx, t) chain, handler := initServerWithInMemoryChain(ctx, t)
defer func() {
require.NoError(t, chain.Close())
}()
t.Run("getbestblockhash", func(t *testing.T) { t.Run("getbestblockhash", func(t *testing.T) {
rpc := `{"jsonrpc": "2.0", "id": 1, "method": "getbestblockhash", "params": []}` rpc := `{"jsonrpc": "2.0", "id": 1, "method": "getbestblockhash", "params": []}`
body := doRPCCall(rpc, handler, t) body := doRPCCall(rpc, handler, t)