diff --git a/go.mod b/go.mod index 5aa463fdb..6f1e947d8 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,7 @@ module github.com/CityOfZion/neo-go require ( + github.com/Workiva/go-datastructures v1.0.50 github.com/abiosoft/ishell v2.0.0+incompatible // indirect github.com/abiosoft/readline v0.0.0-20180607040430-155bce2042db // indirect github.com/alicebob/gopher-json v0.0.0-20180125190556-5a6b3ba71ee6 // indirect diff --git a/go.sum b/go.sum index 4ba89b8ea..3ec2e92ef 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +github.com/Workiva/go-datastructures v1.0.50 h1:slDmfW6KCHcC7U+LP3DDBbm4fqTwZGn1beOFPfGaLvo= +github.com/Workiva/go-datastructures v1.0.50/go.mod h1:Z+F2Rca0qCsVYDS8z7bAGm8f3UkzuWYS/oBZz5a7VVA= github.com/abiosoft/ishell v2.0.0+incompatible h1:zpwIuEHc37EzrsIYah3cpevrIc8Oma7oZPxr03tlmmw= github.com/abiosoft/ishell v2.0.0+incompatible/go.mod h1:HQR9AqF2R3P4XXpMpI0NAzgHf/aS6+zVXRj14cVk9qg= github.com/abiosoft/readline v0.0.0-20180607040430-155bce2042db h1:CjPUSXOiYptLbTdr1RceuZgSFDQ7U15ITERUGrUORx8= diff --git a/pkg/core/account_state.go b/pkg/core/account_state.go index 87c9771a0..ae34d573a 100644 --- a/pkg/core/account_state.go +++ b/pkg/core/account_state.go @@ -12,25 +12,46 @@ import ( // Accounts is mapping between a account address and AccountState. type Accounts map[util.Uint160]*AccountState -func (a Accounts) getAndUpdate(s storage.Store, hash util.Uint160) (*AccountState, error) { +// getAndUpdate retrieves AccountState from temporary or persistent Store +// or creates a new one if it doesn't exist. +func (a Accounts) getAndUpdate(ts storage.Store, ps storage.Store, hash util.Uint160) (*AccountState, error) { if account, ok := a[hash]; ok { return account, nil } - account := &AccountState{} + account, err := getAccountStateFromStore(ts, hash) + if err != nil { + if err != storage.ErrKeyNotFound { + return nil, err + } + account, err = getAccountStateFromStore(ps, hash) + if err != nil { + if err != storage.ErrKeyNotFound { + return nil, err + } + account = NewAccountState(hash) + } + } + + a[hash] = account + return account, nil +} + +// getAccountStateFromStore returns AccountState from the given Store if it's +// present there. Returns nil otherwise. +func getAccountStateFromStore(s storage.Store, hash util.Uint160) (*AccountState, error) { + var account *AccountState key := storage.AppendPrefix(storage.STAccount, hash.Bytes()) - if b, err := s.Get(key); err == nil { + b, err := s.Get(key) + if err == nil { + account = new(AccountState) r := io.NewBinReaderFromBuf(b) account.DecodeBinary(r) if r.Err != nil { return nil, fmt.Errorf("failed to decode (AccountState): %s", r.Err) } - } else { - account = NewAccountState(hash) } - - a[hash] = account - return account, nil + return account, err } // commit writes all account states to the given Batch. diff --git a/pkg/core/block.go b/pkg/core/block.go index 27ec6b225..97fdeabce 100644 --- a/pkg/core/block.go +++ b/pkg/core/block.go @@ -5,6 +5,7 @@ import ( "github.com/CityOfZion/neo-go/pkg/crypto" "github.com/CityOfZion/neo-go/pkg/io" "github.com/CityOfZion/neo-go/pkg/util" + "github.com/Workiva/go-datastructures/queue" log "github.com/sirupsen/logrus" ) @@ -132,3 +133,16 @@ func (b *Block) EncodeBinary(bw *io.BinWriter) { tx.EncodeBinary(bw) } } + +// Compare implements the queue Item interface. +func (b *Block) Compare(item queue.Item) int { + other := item.(*Block) + switch { + case b.Index > other.Index: + return 1 + case b.Index == other.Index: + return 0 + default: + return -1 + } +} diff --git a/pkg/core/block_test.go b/pkg/core/block_test.go index a26007d7e..164a0ab88 100644 --- a/pkg/core/block_test.go +++ b/pkg/core/block_test.go @@ -259,3 +259,12 @@ func TestBlockSizeCalculation(t *testing.T) { assert.Equal(t, 7360, len(benc)) assert.Equal(t, rawBlock, hex.EncodeToString(benc)) } + +func TestBlockCompare(t *testing.T) { + b1 := Block{BlockBase: BlockBase{Index: 1}} + b2 := Block{BlockBase: BlockBase{Index: 2}} + b3 := Block{BlockBase: BlockBase{Index: 3}} + assert.Equal(t, 1, b2.Compare(&b1)) + assert.Equal(t, 0, b2.Compare(&b2)) + assert.Equal(t, -1, b2.Compare(&b3)) +} diff --git a/pkg/core/blockchain.go b/pkg/core/blockchain.go index 24f98ffcb..03e75bedd 100644 --- a/pkg/core/blockchain.go +++ b/pkg/core/blockchain.go @@ -37,16 +37,20 @@ type Blockchain struct { // Any object that satisfies the BlockchainStorer interface. storage.Store + // In-memory storage to be persisted into the storage.Store + memStore *storage.MemoryStore + // Current index/height of the highest block. // Read access should always be called by BlockHeight(). - // Write access should only happen in Persist(). + // Write access should only happen in storeBlock(). blockHeight uint32 + // Current persisted block count. + persistedHeight uint32 + // Number of headers stored in the chain file. storedHeaderCount uint32 - blockCache *Cache - // All operation on headerList must be called from an // headersOp to be routine safe. headerList *HeaderHashList @@ -69,9 +73,9 @@ func NewBlockchain(s storage.Store, cfg config.ProtocolConfiguration) (*Blockcha bc := &Blockchain{ config: cfg, Store: s, + memStore: storage.NewMemoryStore(), headersOp: make(chan headersOpFunc), headersOpDone: make(chan struct{}), - blockCache: NewCache(), verifyBlocks: false, memPool: NewMemPool(50000), } @@ -96,7 +100,7 @@ func (bc *Blockchain) init() error { return err } bc.headerList = NewHeaderHashList(genesisBlock.Hash()) - return bc.persistBlock(genesisBlock) + return bc.storeBlock(genesisBlock) } if ver != version { return fmt.Errorf("storage version mismatch betweeen %s and %s", version, ver) @@ -112,6 +116,7 @@ func (bc *Blockchain) init() error { return err } bc.blockHeight = bHeight + bc.persistedHeight = bHeight hashes, err := storage.HeaderHashes(bc.Store) if err != nil { @@ -144,8 +149,11 @@ func (bc *Blockchain) init() error { } headerSliceReverse(headers) - if err := bc.AddHeaders(headers...); err != nil { - return err + for _, h := range headers { + if !h.Verify() { + return fmt.Errorf("bad header %d/%s in the storage", h.Index, h.Hash()) + } + bc.headerList.Add(h.Hash()) } } @@ -157,6 +165,11 @@ func (bc *Blockchain) Run(ctx context.Context) { persistTimer := time.NewTimer(persistInterval) defer func() { persistTimer.Stop() + if err := bc.persist(ctx); err != nil { + log.Warnf("failed to persist: %s", err) + } + // never fails + _ = bc.memStore.Close() if err := bc.Store.Close(); err != nil { log.Warnf("failed to close db: %s", err) } @@ -170,7 +183,7 @@ func (bc *Blockchain) Run(ctx context.Context) { bc.headersOpDone <- struct{}{} case <-persistTimer.C: go func() { - err := bc.Persist(ctx) + err := bc.persist(ctx) if err != nil { log.Warnf("failed to persist blockchain: %s", err) } @@ -180,24 +193,24 @@ func (bc *Blockchain) Run(ctx context.Context) { } } -// AddBlock processes the given block and will add it to the cache so it -// can be persisted. +// AddBlock accepts successive block for the Blockchain, verifies it and +// stores internally. Eventually it will be persisted to the backing storage. func (bc *Blockchain) AddBlock(block *Block) error { - if !bc.blockCache.Has(block.Hash()) { - bc.blockCache.Add(block.Hash(), block) + expectedHeight := bc.BlockHeight() + 1 + if expectedHeight != block.Index { + return fmt.Errorf("expected block %d, but passed block %d", expectedHeight, block.Index) + } + if bc.verifyBlocks && !block.Verify(false) { + return fmt.Errorf("block %s is invalid", block.Hash()) } - headerLen := bc.headerListLen() - if int(block.Index-1) >= headerLen { - return nil - } if int(block.Index) == headerLen { - if bc.verifyBlocks && !block.Verify(false) { - return fmt.Errorf("block %s is invalid", block.Hash()) + err := bc.AddHeaders(block.Header()) + if err != nil { + return err } - return bc.AddHeaders(block.Header()) } - return nil + return bc.storeBlock(block) } // AddHeaders will process the given headers and add them to the @@ -205,7 +218,7 @@ func (bc *Blockchain) AddBlock(block *Block) error { func (bc *Blockchain) AddHeaders(headers ...*Header) (err error) { var ( start = time.Now() - batch = bc.Batch() + batch = bc.memStore.Batch() ) bc.headersOp <- func(headerList *HeaderHashList) { @@ -230,7 +243,7 @@ func (bc *Blockchain) AddHeaders(headers ...*Header) (err error) { } if batch.Len() > 0 { - if err = bc.PutBatch(batch); err != nil { + if err = bc.memStore.PutBatch(batch); err != nil { return } log.WithFields(log.Fields{ @@ -273,13 +286,13 @@ func (bc *Blockchain) processHeader(h *Header, batch storage.Batch, headerList * return nil } -// TODO: persistBlock needs some more love, its implemented as in the original +// TODO: storeBlock needs some more love, its implemented as in the original // project. This for the sake of development speed and understanding of what // is happening here, quite allot as you can see :). If things are wired together // and all tests are in place, we can make a more optimized and cleaner implementation. -func (bc *Blockchain) persistBlock(block *Block) error { +func (bc *Blockchain) storeBlock(block *Block) error { var ( - batch = bc.Batch() + batch = bc.memStore.Batch() unspentCoins = make(UnspentCoins) spentCoins = make(SpentCoins) accounts = make(Accounts) @@ -301,7 +314,7 @@ func (bc *Blockchain) persistBlock(block *Block) error { // Process TX outputs. for _, output := range tx.Outputs { - account, err := accounts.getAndUpdate(bc.Store, output.ScriptHash) + account, err := accounts.getAndUpdate(bc.memStore, bc.Store, output.ScriptHash) if err != nil { return err } @@ -319,14 +332,14 @@ func (bc *Blockchain) persistBlock(block *Block) error { return fmt.Errorf("could not find previous TX: %s", prevHash) } for _, input := range inputs { - unspent, err := unspentCoins.getAndUpdate(bc.Store, input.PrevHash) + unspent, err := unspentCoins.getAndUpdate(bc.memStore, bc.Store, input.PrevHash) if err != nil { return err } unspent.states[input.PrevIndex] = CoinStateSpent prevTXOutput := prevTX.Outputs[input.PrevIndex] - account, err := accounts.getAndUpdate(bc.Store, prevTXOutput.ScriptHash) + account, err := accounts.getAndUpdate(bc.memStore, bc.Store, prevTXOutput.ScriptHash) if err != nil { return err } @@ -388,7 +401,7 @@ func (bc *Blockchain) persistBlock(block *Block) error { if err := assets.commit(batch); err != nil { return err } - if err := bc.PutBatch(batch); err != nil { + if err := bc.memStore.PutBatch(batch); err != nil { return err } @@ -396,63 +409,37 @@ func (bc *Blockchain) persistBlock(block *Block) error { return nil } -//Persist starts persist loop. -func (bc *Blockchain) Persist(ctx context.Context) (err error) { +// persist flushes current in-memory store contents to the persistent storage. +func (bc *Blockchain) persist(ctx context.Context) error { var ( start = time.Now() persisted = 0 - lenCache = bc.blockCache.Len() + err error ) - if lenCache == 0 { - return nil + persisted, err = bc.memStore.Persist(bc.Store) + if err != nil { + return err } - - bc.headersOp <- func(headerList *HeaderHashList) { - for i := 0; i < lenCache; i++ { - if uint32(headerList.Len()) <= bc.BlockHeight() { - return - } - hash := headerList.Get(int(bc.BlockHeight() + 1)) - if block, ok := bc.blockCache.GetBlock(hash); ok { - if err = bc.persistBlock(block); err != nil { - return - } - bc.blockCache.Delete(hash) - persisted++ - } else { - // no next block in the cache, no reason to continue looping - break - } - } - } - - select { - case <-ctx.Done(): - return - case <-bc.headersOpDone: - // + bHeight, err := storage.CurrentBlockHeight(bc.Store) + if err != nil { + return err } + oldHeight := atomic.SwapUint32(&bc.persistedHeight, bHeight) + diff := bHeight - oldHeight if persisted > 0 { log.WithFields(log.Fields{ - "persisted": persisted, - "headerHeight": bc.HeaderHeight(), - "blockHeight": bc.BlockHeight(), - "took": time.Since(start), + "persistedBlocks": diff, + "persistedKeys": persisted, + "headerHeight": bc.HeaderHeight(), + "blockHeight": bc.BlockHeight(), + "persistedHeight": bc.persistedHeight, + "took": time.Since(start), }).Info("blockchain persist completed") - } else { - // So we have some blocks in cache but can't persist them? - // Either there are some stale blocks there or the other way - // around (which was seen in practice) --- there are some fresh - // blocks that we can't persist yet. Some of the latter can be useful - // or can be bogus (higher than the header height we expect at - // the moment). So try to reap oldies and strange newbies, if - // there are any. - bc.blockCache.ReapStrangeBlocks(bc.BlockHeight(), bc.HeaderHeight()) } - return + return nil } func (bc *Blockchain) headerListLen() (n int) { @@ -468,9 +455,18 @@ func (bc *Blockchain) GetTransaction(hash util.Uint256) (*transaction.Transactio if tx, ok := bc.memPool.TryGetValue(hash); ok { return tx, 0, nil // the height is not actually defined for memPool transaction. Not sure if zero is a good number in this case. } + tx, height, err := getTransactionFromStore(bc.memStore, hash) + if err != nil { + tx, height, err = getTransactionFromStore(bc.Store, hash) + } + return tx, height, err +} +// getTransactionFromStore returns Transaction and its height by the given hash +// if it exists in the store. +func getTransactionFromStore(s storage.Store, hash util.Uint256) (*transaction.Transaction, uint32, error) { key := storage.AppendPrefix(storage.DataTransaction, hash.BytesReverse()) - b, err := bc.Get(key) + b, err := s.Get(key) if err != nil { return nil, 0, err } @@ -490,8 +486,23 @@ func (bc *Blockchain) GetTransaction(hash util.Uint256) (*transaction.Transactio // GetBlock returns a Block by the given hash. func (bc *Blockchain) GetBlock(hash util.Uint256) (*Block, error) { + block, err := getBlockFromStore(bc.memStore, hash) + if err != nil { + block, err = getBlockFromStore(bc.Store, hash) + if err != nil { + return nil, err + } + } + if len(block.Transactions) == 0 { + return nil, fmt.Errorf("only header is available") + } + return block, nil +} + +// getBlockFromStore returns Block by the given hash if it exists in the store. +func getBlockFromStore(s storage.Store, hash util.Uint256) (*Block, error) { key := storage.AppendPrefix(storage.DataBlock, hash.BytesReverse()) - b, err := bc.Get(key) + b, err := s.Get(key) if err != nil { return nil, err } @@ -499,20 +510,24 @@ func (bc *Blockchain) GetBlock(hash util.Uint256) (*Block, error) { if err != nil { return nil, err } - // TODO: persist TX first before we can handle this logic. - // if len(block.Transactions) == 0 { - // return nil, fmt.Errorf("block has no TX") - // } - return block, nil + return block, err } // GetHeader returns data block header identified with the given hash value. func (bc *Blockchain) GetHeader(hash util.Uint256) (*Header, error) { - b, err := bc.Get(storage.AppendPrefix(storage.DataBlock, hash.BytesReverse())) + header, err := getHeaderFromStore(bc.memStore, hash) if err != nil { - return nil, err + header, err = getHeaderFromStore(bc.Store, hash) + if err != nil { + return nil, err + } } - block, err := NewBlockFromTrimmedBytes(b) + return header, err +} + +// getHeaderFromStore returns Header by the given hash from the store. +func getHeaderFromStore(s storage.Store, hash util.Uint256) (*Header, error) { + block, err := getBlockFromStore(s, hash) if err != nil { return nil, err } @@ -522,12 +537,16 @@ func (bc *Blockchain) GetHeader(hash util.Uint256) (*Header, error) { // HasTransaction return true if the blockchain contains he given // transaction hash. func (bc *Blockchain) HasTransaction(hash util.Uint256) bool { - if bc.memPool.ContainsKey(hash) { - return true - } + return bc.memPool.ContainsKey(hash) || + checkTransactionInStore(bc.memStore, hash) || + checkTransactionInStore(bc.Store, hash) +} +// checkTransactionInStore returns true if the given store contains the given +// Transaction hash. +func checkTransactionInStore(s storage.Store, hash util.Uint256) bool { key := storage.AppendPrefix(storage.DataTransaction, hash.BytesReverse()) - if _, err := bc.Get(key); err == nil { + if _, err := s.Get(key); err == nil { return true } return false @@ -582,31 +601,43 @@ func (bc *Blockchain) HeaderHeight() uint32 { // GetAssetState returns asset state from its assetID func (bc *Blockchain) GetAssetState(assetID util.Uint256) *AssetState { - var as *AssetState - bc.Store.Seek(storage.STAsset.Bytes(), func(k, v []byte) { - var a AssetState - r := io.NewBinReaderFromBuf(v) - a.DecodeBinary(r) - if r.Err == nil && a.ID == assetID { - as = &a - } - }) - + as := getAssetStateFromStore(bc.memStore, assetID) + if as == nil { + as = getAssetStateFromStore(bc.Store, assetID) + } return as } +// getAssetStateFromStore returns given asset state as recorded in the given +// store. +func getAssetStateFromStore(s storage.Store, assetID util.Uint256) *AssetState { + key := storage.AppendPrefix(storage.STAsset, assetID.Bytes()) + asEncoded, err := s.Get(key) + if err != nil { + return nil + } + var a AssetState + r := io.NewBinReaderFromBuf(asEncoded) + a.DecodeBinary(r) + if r.Err != nil || a.ID != assetID { + return nil + } + + return &a +} + // GetAccountState returns the account state from its script hash func (bc *Blockchain) GetAccountState(scriptHash util.Uint160) *AccountState { - var as *AccountState - bc.Store.Seek(storage.STAccount.Bytes(), func(k, v []byte) { - var a AccountState - r := io.NewBinReaderFromBuf(v) - a.DecodeBinary(r) - if r.Err == nil && a.ScriptHash == scriptHash { - as = &a + as, err := getAccountStateFromStore(bc.memStore, scriptHash) + if as == nil { + if err != storage.ErrKeyNotFound { + log.Warnf("failed to get account state: %s", err) } - }) - + as, err = getAccountStateFromStore(bc.Store, scriptHash) + if as == nil && err != storage.ErrKeyNotFound { + log.Warnf("failed to get account state: %s", err) + } + } return as } diff --git a/pkg/core/blockchain_test.go b/pkg/core/blockchain_test.go index 018bb3096..1943d1610 100644 --- a/pkg/core/blockchain_test.go +++ b/pkg/core/blockchain_test.go @@ -13,9 +13,6 @@ import ( func TestAddHeaders(t *testing.T) { bc := newTestChain(t) - defer func() { - require.NoError(t, bc.Close()) - }() h1 := newBlock(1).Header() h2 := newBlock(2).Header() h3 := newBlock(3).Header() @@ -24,7 +21,6 @@ func TestAddHeaders(t *testing.T) { t.Fatal(err) } - assert.Equal(t, 0, bc.blockCache.Len()) assert.Equal(t, h3.Index, bc.HeaderHeight()) assert.Equal(t, uint32(0), bc.BlockHeight()) assert.Equal(t, h3.Hash(), bc.CurrentHeaderHash()) @@ -41,9 +37,6 @@ func TestAddHeaders(t *testing.T) { func TestAddBlock(t *testing.T) { bc := newTestChain(t) - defer func() { - require.NoError(t, bc.Close()) - }() blocks := []*Block{ newBlock(1), newBlock(2), @@ -57,15 +50,11 @@ func TestAddBlock(t *testing.T) { } lastBlock := blocks[len(blocks)-1] - assert.Equal(t, 3, bc.blockCache.Len()) assert.Equal(t, lastBlock.Index, bc.HeaderHeight()) assert.Equal(t, lastBlock.Hash(), bc.CurrentHeaderHash()) - t.Log(bc.blockCache) - - if err := bc.Persist(context.Background()); err != nil { - t.Fatal(err) - } + // This one tests persisting blocks, so it does need to persist() + require.NoError(t, bc.persist(context.Background())) for _, block := range blocks { key := storage.AppendPrefix(storage.DataBlock, block.Hash().BytesReverse()) @@ -76,33 +65,30 @@ func TestAddBlock(t *testing.T) { assert.Equal(t, lastBlock.Index, bc.BlockHeight()) assert.Equal(t, lastBlock.Hash(), bc.CurrentHeaderHash()) - assert.Equal(t, 0, bc.blockCache.Len()) } func TestGetHeader(t *testing.T) { bc := newTestChain(t) - defer func() { - require.NoError(t, bc.Close()) - }() block := newBlock(1) err := bc.AddBlock(block) assert.Nil(t, err) - hash := block.Hash() - header, err := bc.GetHeader(hash) - require.NoError(t, err) - assert.Equal(t, block.Header(), header) + // Test unpersisted and persisted access + for i := 0; i < 2; i++ { + hash := block.Hash() + header, err := bc.GetHeader(hash) + require.NoError(t, err) + assert.Equal(t, block.Header(), header) - block = newBlock(2) - _, err = bc.GetHeader(block.Hash()) - assert.Error(t, err) + b2 := newBlock(2) + _, err = bc.GetHeader(b2.Hash()) + assert.Error(t, err) + assert.NoError(t, bc.persist(context.Background())) + } } func TestGetBlock(t *testing.T) { bc := newTestChain(t) - defer func() { - require.NoError(t, bc.Close()) - }() blocks := makeBlocks(100) for i := 0; i < len(blocks); i++ { @@ -111,21 +97,22 @@ func TestGetBlock(t *testing.T) { } } - for i := 0; i < len(blocks); i++ { - block, err := bc.GetBlock(blocks[i].Hash()) - if err != nil { - t.Fatal(err) + // Test unpersisted and persisted access + for j := 0; j < 2; j++ { + for i := 0; i < len(blocks); i++ { + block, err := bc.GetBlock(blocks[i].Hash()) + if err != nil { + t.Fatal(err) + } + assert.Equal(t, blocks[i].Index, block.Index) + assert.Equal(t, blocks[i].Hash(), block.Hash()) } - assert.Equal(t, blocks[i].Index, block.Index) - assert.Equal(t, blocks[i].Hash(), block.Hash()) + assert.NoError(t, bc.persist(context.Background())) } } func TestHasBlock(t *testing.T) { bc := newTestChain(t) - defer func() { - require.NoError(t, bc.Close()) - }() blocks := makeBlocks(50) for i := 0; i < len(blocks); i++ { @@ -133,37 +120,39 @@ func TestHasBlock(t *testing.T) { t.Fatal(err) } } - assert.Nil(t, bc.Persist(context.Background())) - for i := 0; i < len(blocks); i++ { - assert.True(t, bc.HasBlock(blocks[i].Hash())) + // Test unpersisted and persisted access + for j := 0; j < 2; j++ { + for i := 0; i < len(blocks); i++ { + assert.True(t, bc.HasBlock(blocks[i].Hash())) + } + newBlock := newBlock(51) + assert.False(t, bc.HasBlock(newBlock.Hash())) + assert.NoError(t, bc.persist(context.Background())) } - - newBlock := newBlock(51) - assert.False(t, bc.HasBlock(newBlock.Hash())) } func TestGetTransaction(t *testing.T) { + b1 := getDecodedBlock(t, 1) block := getDecodedBlock(t, 2) bc := newTestChain(t) - defer func() { - require.NoError(t, bc.Close()) - }() + assert.Nil(t, bc.AddBlock(b1)) assert.Nil(t, bc.AddBlock(block)) - assert.Nil(t, bc.persistBlock(block)) - tx, height, err := bc.GetTransaction(block.Transactions[0].Hash()) - if err != nil { - t.Fatal(err) + // Test unpersisted and persisted access + for j := 0; j < 2; j++ { + tx, height, err := bc.GetTransaction(block.Transactions[0].Hash()) + require.Nil(t, err) + assert.Equal(t, block.Index, height) + assert.Equal(t, block.Transactions[0], tx) + assert.Equal(t, 10, io.GetVarSize(tx)) + assert.Equal(t, 1, io.GetVarSize(tx.Attributes)) + assert.Equal(t, 1, io.GetVarSize(tx.Inputs)) + assert.Equal(t, 1, io.GetVarSize(tx.Outputs)) + assert.Equal(t, 1, io.GetVarSize(tx.Scripts)) + assert.NoError(t, bc.persist(context.Background())) } - assert.Equal(t, block.Index, height) - assert.Equal(t, block.Transactions[0], tx) - assert.Equal(t, 10, io.GetVarSize(tx)) - assert.Equal(t, 1, io.GetVarSize(tx.Attributes)) - assert.Equal(t, 1, io.GetVarSize(tx.Inputs)) - assert.Equal(t, 1, io.GetVarSize(tx.Outputs)) - assert.Equal(t, 1, io.GetVarSize(tx.Scripts)) } func newTestChain(t *testing.T) *Blockchain { diff --git a/pkg/core/cache.go b/pkg/core/cache.go deleted file mode 100644 index c2141851e..000000000 --- a/pkg/core/cache.go +++ /dev/null @@ -1,87 +0,0 @@ -package core - -import ( - "sync" - - "github.com/CityOfZion/neo-go/pkg/util" -) - -// Cache is data structure with fixed type key of Uint256, but has a -// generic value. Used for block, tx and header cache types. -type Cache struct { - lock sync.RWMutex - m map[util.Uint256]interface{} -} - -// NewCache returns a ready to use Cache object. -func NewCache() *Cache { - return &Cache{ - m: make(map[util.Uint256]interface{}), - } -} - -// GetBlock will return a Block type from the cache. -func (c *Cache) GetBlock(h util.Uint256) (block *Block, ok bool) { - c.lock.RLock() - defer c.lock.RUnlock() - return c.getBlock(h) -} - -func (c *Cache) getBlock(h util.Uint256) (block *Block, ok bool) { - if v, b := c.m[h]; b { - block, ok = v.(*Block) - return - } - return -} - -// Add adds the given hash along with its value to the cache. -func (c *Cache) Add(h util.Uint256, v interface{}) { - c.lock.Lock() - defer c.lock.Unlock() - c.add(h, v) -} - -func (c *Cache) add(h util.Uint256, v interface{}) { - c.m[h] = v -} - -func (c *Cache) has(h util.Uint256) bool { - _, ok := c.m[h] - return ok -} - -// Has returns whether the cache contains the given hash. -func (c *Cache) Has(h util.Uint256) bool { - c.lock.Lock() - defer c.lock.Unlock() - return c.has(h) -} - -// Len return the number of items present in the cache. -func (c *Cache) Len() int { - c.lock.RLock() - defer c.lock.RUnlock() - return len(c.m) -} - -// Delete removes the item out of the cache. -func (c *Cache) Delete(h util.Uint256) { - c.lock.Lock() - defer c.lock.Unlock() - delete(c.m, h) -} - -// ReapStrangeBlocks drops blocks from cache that don't fit into the -// blkHeight-headHeight interval. Cache should only contain blocks that we -// expect to get and store. -func (c *Cache) ReapStrangeBlocks(blkHeight, headHeight uint32) { - c.lock.Lock() - defer c.lock.Unlock() - for i, b := range c.m { - block, ok := b.(*Block) - if ok && (block.Index < blkHeight || block.Index > headHeight) { - delete(c.m, i) - } - } -} diff --git a/pkg/core/storage/boltdb_store.go b/pkg/core/storage/boltdb_store.go index ffa1d4ecb..18283bc63 100644 --- a/pkg/core/storage/boltdb_store.go +++ b/pkg/core/storage/boltdb_store.go @@ -24,25 +24,6 @@ type BoltDBStore struct { db *bbolt.DB } -// BoltDBBatch simple batch implementation to satisfy the Store interface. -type BoltDBBatch struct { - mem map[*[]byte][]byte -} - -// Len implements the Batch interface. -func (b *BoltDBBatch) Len() int { - return len(b.mem) -} - -// Put implements the Batch interface. -func (b *BoltDBBatch) Put(k, v []byte) { - vcopy := make([]byte, len(v)) - copy(vcopy, v) - kcopy := make([]byte, len(k)) - copy(kcopy, k) - b.mem[&kcopy] = vcopy -} - // NewBoltDBStore returns a new ready to use BoltDB storage with created bucket. func NewBoltDBStore(cfg BoltDBOptions) (*BoltDBStore, error) { var opts *bbolt.Options // should be exposed via BoltDBOptions if anything needed @@ -94,7 +75,7 @@ func (s *BoltDBStore) Get(key []byte) (val []byte, err error) { func (s *BoltDBStore) PutBatch(batch Batch) error { return s.db.Batch(func(tx *bbolt.Tx) error { b := tx.Bucket(Bucket) - for k, v := range batch.(*BoltDBBatch).mem { + for k, v := range batch.(*MemoryBatch).m { err := b.Put(*k, v) if err != nil { return err @@ -122,9 +103,7 @@ func (s *BoltDBStore) Seek(key []byte, f func(k, v []byte)) { // Batch implements the Batch interface and returns a boltdb // compatible Batch. func (s *BoltDBStore) Batch() Batch { - return &BoltDBBatch{ - mem: make(map[*[]byte][]byte), - } + return newMemoryBatch() } // Close releases all db resources. diff --git a/pkg/core/storage/boltdb_store_test.go b/pkg/core/storage/boltdb_store_test.go index ac23d136a..e5cf397b1 100644 --- a/pkg/core/storage/boltdb_store_test.go +++ b/pkg/core/storage/boltdb_store_test.go @@ -12,14 +12,14 @@ import ( func TestBoltDBBatch(t *testing.T) { boltDB := BoltDBStore{} - want := &BoltDBBatch{mem: map[*[]byte][]byte{}} + want := &MemoryBatch{m: map[*[]byte][]byte{}} if got := boltDB.Batch(); !reflect.DeepEqual(got, want) { t.Errorf("BoltDB Batch() = %v, want %v", got, want) } } func TestBoltDBBatch_Len(t *testing.T) { - batch := &BoltDBBatch{mem: map[*[]byte][]byte{}} + batch := &MemoryBatch{m: map[*[]byte][]byte{}} want := len(map[*[]byte][]byte{}) assert.Equal(t, want, batch.Len()) } diff --git a/pkg/core/storage/leveldb_store.go b/pkg/core/storage/leveldb_store.go index 1bc6e013b..f024c41d3 100644 --- a/pkg/core/storage/leveldb_store.go +++ b/pkg/core/storage/leveldb_store.go @@ -41,7 +41,11 @@ func (s *LevelDBStore) Put(key, value []byte) error { // Get implements the Store interface. func (s *LevelDBStore) Get(key []byte) ([]byte, error) { - return s.db.Get(key, nil) + value, err := s.db.Get(key, nil) + if err == leveldb.ErrNotFound { + err = ErrKeyNotFound + } + return value, err } // PutBatch implements the Store interface. diff --git a/pkg/core/storage/memory_store.go b/pkg/core/storage/memory_store.go index 4c9d2f33f..789dffdbb 100644 --- a/pkg/core/storage/memory_store.go +++ b/pkg/core/storage/memory_store.go @@ -9,7 +9,7 @@ import ( // MemoryStore is an in-memory implementation of a Store, mainly // used for testing. Do not use MemoryStore in production. type MemoryStore struct { - *sync.RWMutex + mut sync.RWMutex mem map[string][]byte } @@ -20,8 +20,11 @@ type MemoryBatch struct { // Put implements the Batch interface. func (b *MemoryBatch) Put(k, v []byte) { - key := &k - b.m[key] = v + vcopy := make([]byte, len(v)) + copy(vcopy, v) + kcopy := make([]byte, len(k)) + copy(kcopy, k) + b.m[&kcopy] = vcopy } // Len implements the Batch interface. @@ -32,36 +35,33 @@ func (b *MemoryBatch) Len() int { // NewMemoryStore creates a new MemoryStore object. func NewMemoryStore() *MemoryStore { return &MemoryStore{ - RWMutex: new(sync.RWMutex), - mem: make(map[string][]byte), + mem: make(map[string][]byte), } } // Get implements the Store interface. func (s *MemoryStore) Get(key []byte) ([]byte, error) { - s.RLock() - defer s.RUnlock() + s.mut.RLock() + defer s.mut.RUnlock() if val, ok := s.mem[makeKey(key)]; ok { return val, nil } return nil, ErrKeyNotFound } -// Put implements the Store interface. +// Put implements the Store interface. Never returns an error. func (s *MemoryStore) Put(key, value []byte) error { - s.Lock() + s.mut.Lock() s.mem[makeKey(key)] = value - s.Unlock() + s.mut.Unlock() return nil } -// PutBatch implements the Store interface. +// PutBatch implements the Store interface. Never returns an error. func (s *MemoryStore) PutBatch(batch Batch) error { b := batch.(*MemoryBatch) for k, v := range b.m { - if err := s.Put(*k, v); err != nil { - return err - } + _ = s.Put(*k, v) } return nil } @@ -78,16 +78,44 @@ func (s *MemoryStore) Seek(key []byte, f func(k, v []byte)) { // Batch implements the Batch interface and returns a compatible Batch. func (s *MemoryStore) Batch() Batch { + return newMemoryBatch() +} + +// newMemoryBatch returns new memory batch. +func newMemoryBatch() *MemoryBatch { return &MemoryBatch{ m: make(map[*[]byte][]byte), } } -// Close implements Store interface and clears up memory. +// Persist flushes all the MemoryStore contents into the (supposedly) persistent +// store provided via parameter. +func (s *MemoryStore) Persist(ps Store) (int, error) { + s.mut.Lock() + defer s.mut.Unlock() + batch := ps.Batch() + keys := 0 + for k, v := range s.mem { + kb, _ := hex.DecodeString(k) + batch.Put(kb, v) + keys++ + } + var err error + if keys != 0 { + err = ps.PutBatch(batch) + } + if err == nil { + s.mem = make(map[string][]byte) + } + return keys, err +} + +// Close implements Store interface and clears up memory. Never returns an +// error. func (s *MemoryStore) Close() error { - s.Lock() + s.mut.Lock() s.mem = nil - s.Unlock() + s.mut.Unlock() return nil } diff --git a/pkg/core/storage/memory_store_test.go b/pkg/core/storage/memory_store_test.go index 1ad760e3e..30aa526fd 100644 --- a/pkg/core/storage/memory_store_test.go +++ b/pkg/core/storage/memory_store_test.go @@ -75,3 +75,54 @@ func TestMemoryStore_Seek(t *testing.T) { assert.Equal(t, value, v) }) } + +func TestMemoryStorePersist(t *testing.T) { + // temporary Store + ts := NewMemoryStore() + // persistent Store + ps := NewMemoryStore() + // persisting nothing should do nothing + c, err := ts.Persist(ps) + assert.Equal(t, nil, err) + assert.Equal(t, 0, c) + // persisting one key should result in one key in ps and nothing in ts + assert.NoError(t, ts.Put([]byte("key"), []byte("value"))) + c, err = ts.Persist(ps) + assert.Equal(t, nil, err) + assert.Equal(t, 1, c) + v, err := ps.Get([]byte("key")) + assert.Equal(t, nil, err) + assert.Equal(t, []byte("value"), v) + v, err = ts.Get([]byte("key")) + assert.Equal(t, ErrKeyNotFound, err) + assert.Equal(t, []byte(nil), v) + // now we overwrite the previous `key` contents and also add `key2`, + assert.NoError(t, ts.Put([]byte("key"), []byte("newvalue"))) + assert.NoError(t, ts.Put([]byte("key2"), []byte("value2"))) + // this is to check that now key is written into the ps before we do + // persist + v, err = ps.Get([]byte("key2")) + assert.Equal(t, ErrKeyNotFound, err) + assert.Equal(t, []byte(nil), v) + // two keys should be persisted (one overwritten and one new) and + // available in the ps + c, err = ts.Persist(ps) + assert.Equal(t, nil, err) + assert.Equal(t, 2, c) + v, err = ts.Get([]byte("key")) + assert.Equal(t, ErrKeyNotFound, err) + assert.Equal(t, []byte(nil), v) + v, err = ts.Get([]byte("key2")) + assert.Equal(t, ErrKeyNotFound, err) + assert.Equal(t, []byte(nil), v) + v, err = ps.Get([]byte("key")) + assert.Equal(t, nil, err) + assert.Equal(t, []byte("newvalue"), v) + v, err = ps.Get([]byte("key2")) + assert.Equal(t, nil, err) + assert.Equal(t, []byte("value2"), v) + // we've persisted some values, make sure successive persist is a no-op + c, err = ts.Persist(ps) + assert.Equal(t, nil, err) + assert.Equal(t, 0, c) +} diff --git a/pkg/core/storage/redis_store.go b/pkg/core/storage/redis_store.go index a401ce08d..5dd5ac4e0 100644 --- a/pkg/core/storage/redis_store.go +++ b/pkg/core/storage/redis_store.go @@ -18,28 +18,6 @@ type RedisStore struct { client *redis.Client } -// RedisBatch simple batch implementation to satisfy the Store interface. -type RedisBatch struct { - mem map[string]string -} - -// Len implements the Batch interface. -func (b *RedisBatch) Len() int { - return len(b.mem) -} - -// Put implements the Batch interface. -func (b *RedisBatch) Put(k, v []byte) { - b.mem[string(k)] = string(v) -} - -// NewRedisBatch returns a new ready to use RedisBatch. -func NewRedisBatch() *RedisBatch { - return &RedisBatch{ - mem: make(map[string]string), - } -} - // NewRedisStore returns an new initialized - ready to use RedisStore object. func NewRedisStore(cfg RedisDBOptions) (*RedisStore, error) { c := redis.NewClient(&redis.Options{ @@ -55,13 +33,16 @@ func NewRedisStore(cfg RedisDBOptions) (*RedisStore, error) { // Batch implements the Store interface. func (s *RedisStore) Batch() Batch { - return NewRedisBatch() + return newMemoryBatch() } // Get implements the Store interface. func (s *RedisStore) Get(k []byte) ([]byte, error) { val, err := s.client.Get(string(k)).Result() if err != nil { + if err == redis.Nil { + err = ErrKeyNotFound + } return nil, err } return []byte(val), nil @@ -76,8 +57,8 @@ func (s *RedisStore) Put(k, v []byte) error { // PutBatch implements the Store interface. func (s *RedisStore) PutBatch(b Batch) error { pipe := s.client.Pipeline() - for k, v := range b.(*RedisBatch).mem { - pipe.Set(k, v, 0) + for k, v := range b.(*MemoryBatch).m { + pipe.Set(string(*k), v, 0) } _, err := pipe.Exec() return err diff --git a/pkg/core/storage/redis_store_test.go b/pkg/core/storage/redis_store_test.go index 12403b974..f60f76922 100644 --- a/pkg/core/storage/redis_store_test.go +++ b/pkg/core/storage/redis_store_test.go @@ -1,7 +1,6 @@ package storage import ( - "reflect" "testing" "github.com/alicebob/miniredis" @@ -9,13 +8,6 @@ import ( "github.com/stretchr/testify/require" ) -func TestNewRedisBatch(t *testing.T) { - want := &RedisBatch{mem: map[string]string{}} - if got := NewRedisBatch(); !reflect.DeepEqual(got, want) { - t.Errorf("NewRedisBatch() = %v, want %v", got, want) - } -} - func TestNewRedisStore(t *testing.T) { redisMock, redisStore := prepareRedisMock(t) key := []byte("testKey") @@ -33,50 +25,10 @@ func TestNewRedisStore(t *testing.T) { func TestRedisBatch_Len(t *testing.T) { want := len(map[string]string{}) - b := &RedisBatch{ - mem: map[string]string{}, + b := &MemoryBatch{ + m: map[*[]byte][]byte{}, } - assert.Equal(t, len(b.mem), want) -} - -func TestRedisBatch_Put(t *testing.T) { - type args struct { - k []byte - v []byte - } - tests := []struct { - name string - args args - want *RedisBatch - }{ - {"TestRedisBatch_Put_Strings", - args{ - k: []byte("foo"), - v: []byte("bar"), - }, - &RedisBatch{mem: map[string]string{"foo": "bar"}}, - }, - {"TestRedisBatch_Put_Numbers", - args{ - k: []byte("123"), - v: []byte("456"), - }, - &RedisBatch{mem: map[string]string{"123": "456"}}, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - actual := &RedisBatch{mem: map[string]string{}} - actual.Put(tt.args.k, tt.args.v) - assert.Equal(t, tt.want, actual) - }) - } -} - -func TestRedisStore_Batch(t *testing.T) { - want := &RedisBatch{mem: map[string]string{}} - actual := NewRedisBatch() - assert.Equal(t, want, actual) + assert.Equal(t, len(b.m), want) } func TestRedisStore_GetAndPut(t *testing.T) { @@ -130,7 +82,7 @@ func TestRedisStore_GetAndPut(t *testing.T) { } func TestRedisStore_PutBatch(t *testing.T) { - batch := &RedisBatch{mem: map[string]string{"foo1": "bar1"}} + batch := &MemoryBatch{m: map[*[]byte][]byte{&[]byte{'f', 'o', 'o', '1'}: []byte("bar1")}} mock, redisStore := prepareRedisMock(t) err := redisStore.PutBatch(batch) assert.Nil(t, err, "Error while PutBatch") diff --git a/pkg/core/unspent_coin_state.go b/pkg/core/unspent_coin_state.go index 9f8026578..1b7aa408c 100644 --- a/pkg/core/unspent_coin_state.go +++ b/pkg/core/unspent_coin_state.go @@ -13,11 +13,36 @@ import ( // coin state. type UnspentCoins map[util.Uint256]*UnspentCoinState -func (u UnspentCoins) getAndUpdate(s storage.Store, hash util.Uint256) (*UnspentCoinState, error) { +// getAndUpdate retreives UnspentCoinState from temporary or persistent Store +// and return it. If it's not present in both stores, returns a new +// UnspentCoinState. +func (u UnspentCoins) getAndUpdate(ts storage.Store, ps storage.Store, hash util.Uint256) (*UnspentCoinState, error) { if unspent, ok := u[hash]; ok { return unspent, nil } + unspent, err := getUnspentCoinStateFromStore(ts, hash) + if err != nil { + if err != storage.ErrKeyNotFound { + return nil, err + } + unspent, err = getUnspentCoinStateFromStore(ps, hash) + if err != nil { + if err != storage.ErrKeyNotFound { + return nil, err + } + unspent = &UnspentCoinState{ + states: []CoinState{}, + } + } + } + + u[hash] = unspent + return unspent, nil +} + +// getUnspentCoinStateFromStore retrieves UnspentCoinState from the given store +func getUnspentCoinStateFromStore(s storage.Store, hash util.Uint256) (*UnspentCoinState, error) { unspent := &UnspentCoinState{} key := storage.AppendPrefix(storage.STCoin, hash.BytesReverse()) if b, err := s.Get(key); err == nil { @@ -27,12 +52,8 @@ func (u UnspentCoins) getAndUpdate(s storage.Store, hash util.Uint256) (*Unspent return nil, fmt.Errorf("failed to decode (UnspentCoinState): %s", r.Err) } } else { - unspent = &UnspentCoinState{ - states: []CoinState{}, - } + return nil, err } - - u[hash] = unspent return unspent, nil } diff --git a/pkg/network/blockqueue.go b/pkg/network/blockqueue.go new file mode 100644 index 000000000..501ddd804 --- /dev/null +++ b/pkg/network/blockqueue.go @@ -0,0 +1,77 @@ +package network + +import ( + "github.com/CityOfZion/neo-go/pkg/core" + "github.com/Workiva/go-datastructures/queue" + log "github.com/sirupsen/logrus" +) + +type blockQueue struct { + queue *queue.PriorityQueue + checkBlocks chan struct{} + chain core.Blockchainer +} + +func newBlockQueue(capacity int, bc core.Blockchainer) *blockQueue { + return &blockQueue{ + queue: queue.NewPriorityQueue(capacity, false), + checkBlocks: make(chan struct{}, 1), + chain: bc, + } +} + +func (bq *blockQueue) run() { + for { + _, ok := <-bq.checkBlocks + if !ok { + break + } + for { + item := bq.queue.Peek() + if item == nil { + break + } + minblock := item.(*core.Block) + if minblock.Index <= bq.chain.BlockHeight()+1 { + _, _ = bq.queue.Get(1) + if minblock.Index == bq.chain.BlockHeight()+1 { + err := bq.chain.AddBlock(minblock) + if err != nil { + log.WithFields(log.Fields{ + "error": err.Error(), + "blockHeight": bq.chain.BlockHeight(), + "nextIndex": minblock.Index, + }).Warn("blockQueue: failed adding block into the blockchain") + } + } + } else { + break + } + } + } +} + +func (bq *blockQueue) putBlock(block *core.Block) error { + if bq.chain.BlockHeight() >= block.Index { + // can easily happen when fetching the same blocks from + // different peers, thus not considered as error + return nil + } + err := bq.queue.Put(block) + select { + case bq.checkBlocks <- struct{}{}: + // ok, signalled to goroutine processing queue + default: + // it's already busy processing blocks + } + return err +} + +func (bq *blockQueue) discard() { + close(bq.checkBlocks) + bq.queue.Dispose() +} + +func (bq *blockQueue) length() int { + return bq.queue.Len() +} diff --git a/pkg/network/blockqueue_test.go b/pkg/network/blockqueue_test.go new file mode 100644 index 000000000..da4124c06 --- /dev/null +++ b/pkg/network/blockqueue_test.go @@ -0,0 +1,71 @@ +package network + +import ( + "testing" + "time" + + "github.com/CityOfZion/neo-go/pkg/core" + "github.com/stretchr/testify/assert" +) + +func TestBlockQueue(t *testing.T) { + chain := &testChain{} + // notice, it's not yet running + bq := newBlockQueue(0, chain) + blocks := make([]*core.Block, 11) + for i := 1; i < 11; i++ { + blocks[i] = &core.Block{BlockBase: core.BlockBase{Index: uint32(i)}} + } + // not the ones expected currently + for i := 3; i < 5; i++ { + assert.NoError(t, bq.putBlock(blocks[i])) + } + // nothing should be put into the blockchain + assert.Equal(t, uint32(0), chain.BlockHeight()) + assert.Equal(t, 2, bq.length()) + // now added expected ones (with duplicates) + for i := 1; i < 5; i++ { + assert.NoError(t, bq.putBlock(blocks[i])) + } + // but they're still not put into the blockchain, because bq isn't running + assert.Equal(t, uint32(0), chain.BlockHeight()) + assert.Equal(t, 4, bq.length()) + go bq.run() + // run() is asynchronous, so we need some kind of timeout anyway and this is the simplest one + for i := 0; i < 5; i++ { + if chain.BlockHeight() != 4 { + time.Sleep(time.Second) + } + } + assert.Equal(t, 0, bq.length()) + assert.Equal(t, uint32(4), chain.BlockHeight()) + // put some old blocks + for i := 1; i < 5; i++ { + assert.NoError(t, bq.putBlock(blocks[i])) + } + assert.Equal(t, 0, bq.length()) + assert.Equal(t, uint32(4), chain.BlockHeight()) + // unexpected blocks with run() active + assert.NoError(t, bq.putBlock(blocks[8])) + assert.Equal(t, 1, bq.length()) + assert.Equal(t, uint32(4), chain.BlockHeight()) + assert.NoError(t, bq.putBlock(blocks[7])) + assert.Equal(t, 2, bq.length()) + assert.Equal(t, uint32(4), chain.BlockHeight()) + // sparse put + assert.NoError(t, bq.putBlock(blocks[10])) + assert.Equal(t, 3, bq.length()) + assert.Equal(t, uint32(4), chain.BlockHeight()) + assert.NoError(t, bq.putBlock(blocks[6])) + assert.NoError(t, bq.putBlock(blocks[5])) + // run() is asynchronous, so we need some kind of timeout anyway and this is the simplest one + for i := 0; i < 5; i++ { + if chain.BlockHeight() != 8 { + time.Sleep(time.Second) + } + } + assert.Equal(t, 1, bq.length()) + assert.Equal(t, uint32(8), chain.BlockHeight()) + bq.discard() + assert.Equal(t, 0, bq.length()) +} diff --git a/pkg/network/helper_test.go b/pkg/network/helper_test.go index bcbdff5df..5580f3873 100644 --- a/pkg/network/helper_test.go +++ b/pkg/network/helper_test.go @@ -3,6 +3,7 @@ package network import ( "math/rand" "net" + "sync/atomic" "testing" "time" @@ -13,7 +14,9 @@ import ( "github.com/CityOfZion/neo-go/pkg/util" ) -type testChain struct{} +type testChain struct { + blockheight uint32 +} func (chain testChain) GetConfig() config.ProtocolConfiguration { panic("TODO") @@ -38,11 +41,14 @@ func (chain testChain) NetworkFee(t *transaction.Transaction) util.Fixed8 { func (chain testChain) AddHeaders(...*core.Header) error { panic("TODO") } -func (chain testChain) AddBlock(*core.Block) error { - panic("TODO") +func (chain *testChain) AddBlock(block *core.Block) error { + if block.Index == chain.blockheight+1 { + atomic.StoreUint32(&chain.blockheight, block.Index) + } + return nil } -func (chain testChain) BlockHeight() uint32 { - return 0 +func (chain *testChain) BlockHeight() uint32 { + return atomic.LoadUint32(&chain.blockheight) } func (chain testChain) HeaderHeight() uint32 { return 0 @@ -168,7 +174,7 @@ func (p *localPeer) Handshaked() bool { func newTestServer() *Server { return &Server{ ServerConfig: ServerConfig{}, - chain: testChain{}, + chain: &testChain{}, transport: localTransport{}, discovery: testDiscovery{}, id: rand.Uint32(), diff --git a/pkg/network/server.go b/pkg/network/server.go index 1de1b060d..a1caa2fa5 100644 --- a/pkg/network/server.go +++ b/pkg/network/server.go @@ -45,6 +45,7 @@ type ( transport Transporter discovery Discoverer chain core.Blockchainer + bQueue *blockQueue lock sync.RWMutex peers map[Peer]bool @@ -66,6 +67,7 @@ func NewServer(config ServerConfig, chain core.Blockchainer) *Server { s := &Server{ ServerConfig: config, chain: chain, + bQueue: newBlockQueue(maxBlockBatch, chain), id: rand.Uint32(), quit: make(chan struct{}), addrReq: make(chan *Message, minPeers), @@ -97,6 +99,7 @@ func (s *Server) Start(errChan chan error) { s.discovery.BackFill(s.Seeds...) + go s.bQueue.run() go s.transport.Accept() s.run() } @@ -106,6 +109,7 @@ func (s *Server) Shutdown() { log.WithFields(log.Fields{ "peers": s.PeerCount(), }).Info("shutting down server") + s.bQueue.discard() close(s.quit) } @@ -273,10 +277,7 @@ func (s *Server) handleHeadersCmd(p Peer, headers *payload.Headers) { // handleBlockCmd processes the received block received from its peer. func (s *Server) handleBlockCmd(p Peer, block *core.Block) error { - if !s.chain.HasBlock(block.Hash()) { - return s.chain.AddBlock(block) - } - return nil + return s.bQueue.putBlock(block) } // handleInvCmd will process the received inventory. @@ -329,7 +330,7 @@ func (s *Server) requestBlocks(p Peer) error { hashStart = s.chain.BlockHeight() + 1 headerHeight = s.chain.HeaderHeight() ) - for hashStart < headerHeight && len(hashes) < maxBlockBatch { + for hashStart <= headerHeight && len(hashes) < maxBlockBatch { hash := s.chain.GetHeaderHash(int(hashStart)) hashes = append(hashes, hash) hashStart++ diff --git a/pkg/rpc/server_helper_test.go b/pkg/rpc/server_helper_test.go index 2eb7331ac..fd818750f 100644 --- a/pkg/rpc/server_helper_test.go +++ b/pkg/rpc/server_helper_test.go @@ -167,7 +167,6 @@ func initBlocks(t *testing.T, chain *core.Blockchain) { for i := 0; i < len(blocks); i++ { require.NoError(t, chain.AddBlock(blocks[i])) } - require.NoError(t, chain.Persist(context.Background())) } func makeBlocks(n int) []*core.Block { diff --git a/pkg/rpc/server_test.go b/pkg/rpc/server_test.go index 2db969c7c..3d6c139e1 100644 --- a/pkg/rpc/server_test.go +++ b/pkg/rpc/server_test.go @@ -13,7 +13,6 @@ import ( "time" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func TestRPC(t *testing.T) { @@ -22,9 +21,6 @@ func TestRPC(t *testing.T) { chain, handler := initServerWithInMemoryChain(ctx, t) - defer func() { - require.NoError(t, chain.Close()) - }() t.Run("getbestblockhash", func(t *testing.T) { rpc := `{"jsonrpc": "2.0", "id": 1, "method": "getbestblockhash", "params": []}` body := doRPCCall(rpc, handler, t)