diff --git a/pkg/core/blockchain.go b/pkg/core/blockchain.go index 24f98ffcb..a416cd39e 100644 --- a/pkg/core/blockchain.go +++ b/pkg/core/blockchain.go @@ -39,9 +39,12 @@ type Blockchain struct { // Current index/height of the highest block. // Read access should always be called by BlockHeight(). - // Write access should only happen in Persist(). + // Write access should only happen in AddBlock(). blockHeight uint32 + // Current persisted block count. + persistedHeight uint32 + // Number of headers stored in the chain file. storedHeaderCount uint32 @@ -112,6 +115,7 @@ func (bc *Blockchain) init() error { return err } bc.blockHeight = bHeight + bc.persistedHeight = bHeight hashes, err := storage.HeaderHashes(bc.Store) if err != nil { @@ -170,7 +174,7 @@ func (bc *Blockchain) Run(ctx context.Context) { bc.headersOpDone <- struct{}{} case <-persistTimer.C: go func() { - err := bc.Persist(ctx) + err := bc.persist(ctx) if err != nil { log.Warnf("failed to persist blockchain: %s", err) } @@ -195,7 +199,13 @@ func (bc *Blockchain) AddBlock(block *Block) error { if bc.verifyBlocks && !block.Verify(false) { return fmt.Errorf("block %s is invalid", block.Hash()) } - return bc.AddHeaders(block.Header()) + err := bc.AddHeaders(block.Header()) + if err != nil { + return err + } + } + if bc.BlockHeight()+1 == block.Index { + atomic.StoreUint32(&bc.blockHeight, block.Index) } return nil } @@ -392,12 +402,12 @@ func (bc *Blockchain) persistBlock(block *Block) error { return err } - atomic.StoreUint32(&bc.blockHeight, block.Index) + bc.persistedHeight = block.Index return nil } -//Persist starts persist loop. -func (bc *Blockchain) Persist(ctx context.Context) (err error) { +// persist flushed current block cache to the persistent storage. +func (bc *Blockchain) persist(ctx context.Context) (err error) { var ( start = time.Now() persisted = 0 @@ -413,7 +423,7 @@ func (bc *Blockchain) Persist(ctx context.Context) (err error) { if uint32(headerList.Len()) <= bc.BlockHeight() { return } - hash := headerList.Get(int(bc.BlockHeight() + 1)) + hash := headerList.Get(int(bc.persistedHeight + 1)) if block, ok := bc.blockCache.GetBlock(hash); ok { if err = bc.persistBlock(block); err != nil { return @@ -465,6 +475,9 @@ func (bc *Blockchain) headerListLen() (n int) { // GetTransaction returns a TX and its height by the given hash. func (bc *Blockchain) GetTransaction(hash util.Uint256) (*transaction.Transaction, uint32, error) { + if tx, height, ok := bc.blockCache.GetTransaction(hash); ok { + return tx, height, nil + } if tx, ok := bc.memPool.TryGetValue(hash); ok { return tx, 0, nil // the height is not actually defined for memPool transaction. Not sure if zero is a good number in this case. } @@ -490,31 +503,36 @@ func (bc *Blockchain) GetTransaction(hash util.Uint256) (*transaction.Transactio // GetBlock returns a Block by the given hash. func (bc *Blockchain) GetBlock(hash util.Uint256) (*Block, error) { - key := storage.AppendPrefix(storage.DataBlock, hash.BytesReverse()) - b, err := bc.Get(key) - if err != nil { - return nil, err + block, ok := bc.blockCache.GetBlock(hash) + if !ok { + key := storage.AppendPrefix(storage.DataBlock, hash.BytesReverse()) + b, err := bc.Get(key) + if err != nil { + return nil, err + } + block, err = NewBlockFromTrimmedBytes(b) + if err != nil { + return nil, err + } } - block, err := NewBlockFromTrimmedBytes(b) - if err != nil { - return nil, err + if len(block.Transactions) == 0 { + return nil, fmt.Errorf("only header is available") } - // TODO: persist TX first before we can handle this logic. - // if len(block.Transactions) == 0 { - // return nil, fmt.Errorf("block has no TX") - // } return block, nil } // GetHeader returns data block header identified with the given hash value. func (bc *Blockchain) GetHeader(hash util.Uint256) (*Header, error) { - b, err := bc.Get(storage.AppendPrefix(storage.DataBlock, hash.BytesReverse())) - if err != nil { - return nil, err - } - block, err := NewBlockFromTrimmedBytes(b) - if err != nil { - return nil, err + block, ok := bc.blockCache.GetBlock(hash) + if !ok { + b, err := bc.Get(storage.AppendPrefix(storage.DataBlock, hash.BytesReverse())) + if err != nil { + return nil, err + } + block, err = NewBlockFromTrimmedBytes(b) + if err != nil { + return nil, err + } } return block.Header(), nil } @@ -522,6 +540,9 @@ func (bc *Blockchain) GetHeader(hash util.Uint256) (*Header, error) { // HasTransaction return true if the blockchain contains he given // transaction hash. func (bc *Blockchain) HasTransaction(hash util.Uint256) bool { + if _, _, ok := bc.blockCache.GetTransaction(hash); ok { + return true + } if bc.memPool.ContainsKey(hash) { return true } @@ -536,6 +557,10 @@ func (bc *Blockchain) HasTransaction(hash util.Uint256) bool { // HasBlock return true if the blockchain contains the given // block hash. func (bc *Blockchain) HasBlock(hash util.Uint256) bool { + if bc.blockCache.Has(hash) { + return true + } + if header, err := bc.GetHeader(hash); err == nil { return header.Index <= bc.BlockHeight() } diff --git a/pkg/core/blockchain_test.go b/pkg/core/blockchain_test.go index 018bb3096..e70c5b66d 100644 --- a/pkg/core/blockchain_test.go +++ b/pkg/core/blockchain_test.go @@ -63,9 +63,8 @@ func TestAddBlock(t *testing.T) { t.Log(bc.blockCache) - if err := bc.Persist(context.Background()); err != nil { - t.Fatal(err) - } + // This one tests persisting blocks, so it does need to persist() + require.NoError(t, bc.persist(context.Background())) for _, block := range blocks { key := storage.AppendPrefix(storage.DataBlock, block.Hash().BytesReverse()) @@ -88,14 +87,18 @@ func TestGetHeader(t *testing.T) { err := bc.AddBlock(block) assert.Nil(t, err) - hash := block.Hash() - header, err := bc.GetHeader(hash) - require.NoError(t, err) - assert.Equal(t, block.Header(), header) + // Test unpersisted and persisted access + for i := 0; i < 2; i++ { + hash := block.Hash() + header, err := bc.GetHeader(hash) + require.NoError(t, err) + assert.Equal(t, block.Header(), header) - block = newBlock(2) - _, err = bc.GetHeader(block.Hash()) - assert.Error(t, err) + b2 := newBlock(2) + _, err = bc.GetHeader(b2.Hash()) + assert.Error(t, err) + assert.NoError(t, bc.persist(context.Background())) + } } func TestGetBlock(t *testing.T) { @@ -111,13 +114,17 @@ func TestGetBlock(t *testing.T) { } } - for i := 0; i < len(blocks); i++ { - block, err := bc.GetBlock(blocks[i].Hash()) - if err != nil { - t.Fatal(err) + // Test unpersisted and persisted access + for j := 0; j < 2; j++ { + for i := 0; i < len(blocks); i++ { + block, err := bc.GetBlock(blocks[i].Hash()) + if err != nil { + t.Fatal(err) + } + assert.Equal(t, blocks[i].Index, block.Index) + assert.Equal(t, blocks[i].Hash(), block.Hash()) } - assert.Equal(t, blocks[i].Index, block.Index) - assert.Equal(t, blocks[i].Hash(), block.Hash()) + assert.NoError(t, bc.persist(context.Background())) } } @@ -133,14 +140,16 @@ func TestHasBlock(t *testing.T) { t.Fatal(err) } } - assert.Nil(t, bc.Persist(context.Background())) - for i := 0; i < len(blocks); i++ { - assert.True(t, bc.HasBlock(blocks[i].Hash())) + // Test unpersisted and persisted access + for j := 0; j < 2; j++ { + for i := 0; i < len(blocks); i++ { + assert.True(t, bc.HasBlock(blocks[i].Hash())) + } + newBlock := newBlock(51) + assert.False(t, bc.HasBlock(newBlock.Hash())) + assert.NoError(t, bc.persist(context.Background())) } - - newBlock := newBlock(51) - assert.False(t, bc.HasBlock(newBlock.Hash())) } func TestGetTransaction(t *testing.T) { @@ -151,19 +160,20 @@ func TestGetTransaction(t *testing.T) { }() assert.Nil(t, bc.AddBlock(block)) - assert.Nil(t, bc.persistBlock(block)) - tx, height, err := bc.GetTransaction(block.Transactions[0].Hash()) - if err != nil { - t.Fatal(err) + // Test unpersisted and persisted access + for j := 0; j < 2; j++ { + tx, height, err := bc.GetTransaction(block.Transactions[0].Hash()) + require.Nil(t, err) + assert.Equal(t, block.Index, height) + assert.Equal(t, block.Transactions[0], tx) + assert.Equal(t, 10, io.GetVarSize(tx)) + assert.Equal(t, 1, io.GetVarSize(tx.Attributes)) + assert.Equal(t, 1, io.GetVarSize(tx.Inputs)) + assert.Equal(t, 1, io.GetVarSize(tx.Outputs)) + assert.Equal(t, 1, io.GetVarSize(tx.Scripts)) + assert.NoError(t, bc.persist(context.Background())) } - assert.Equal(t, block.Index, height) - assert.Equal(t, block.Transactions[0], tx) - assert.Equal(t, 10, io.GetVarSize(tx)) - assert.Equal(t, 1, io.GetVarSize(tx.Attributes)) - assert.Equal(t, 1, io.GetVarSize(tx.Inputs)) - assert.Equal(t, 1, io.GetVarSize(tx.Outputs)) - assert.Equal(t, 1, io.GetVarSize(tx.Scripts)) } func newTestChain(t *testing.T) *Blockchain { diff --git a/pkg/core/cache.go b/pkg/core/cache.go index c2141851e..9208d9be9 100644 --- a/pkg/core/cache.go +++ b/pkg/core/cache.go @@ -3,6 +3,7 @@ package core import ( "sync" + "github.com/CityOfZion/neo-go/pkg/core/transaction" "github.com/CityOfZion/neo-go/pkg/util" ) @@ -13,6 +14,12 @@ type Cache struct { m map[util.Uint256]interface{} } +// txWithHeight is an ugly wrapper to fit the needs of Blockchain's GetTransaction. +type txWithHeight struct { + tx *transaction.Transaction + height uint32 +} + // NewCache returns a ready to use Cache object. func NewCache() *Cache { return &Cache{ @@ -20,6 +27,19 @@ func NewCache() *Cache { } } +// GetTransaction will return a Transaction type from the cache. +func (c *Cache) GetTransaction(h util.Uint256) (*transaction.Transaction, uint32, bool) { + c.lock.RLock() + defer c.lock.RUnlock() + if v, ok := c.m[h]; ok { + txh, ok := v.(txWithHeight) + if ok { + return txh.tx, txh.height, ok + } + } + return nil, 0, false +} + // GetBlock will return a Block type from the cache. func (c *Cache) GetBlock(h util.Uint256) (block *Block, ok bool) { c.lock.RLock() @@ -44,6 +64,12 @@ func (c *Cache) Add(h util.Uint256, v interface{}) { func (c *Cache) add(h util.Uint256, v interface{}) { c.m[h] = v + block, ok := v.(*Block) + if ok { + for _, tx := range block.Transactions { + c.m[tx.Hash()] = txWithHeight{tx, block.Index} + } + } } func (c *Cache) has(h util.Uint256) bool { @@ -69,6 +95,12 @@ func (c *Cache) Len() int { func (c *Cache) Delete(h util.Uint256) { c.lock.Lock() defer c.lock.Unlock() + block, ok := c.m[h].(*Block) + if ok { + for _, tx := range block.Transactions { + delete(c.m, tx.Hash()) + } + } delete(c.m, h) } diff --git a/pkg/rpc/server_helper_test.go b/pkg/rpc/server_helper_test.go index 2eb7331ac..fd818750f 100644 --- a/pkg/rpc/server_helper_test.go +++ b/pkg/rpc/server_helper_test.go @@ -167,7 +167,6 @@ func initBlocks(t *testing.T, chain *core.Blockchain) { for i := 0; i < len(blocks); i++ { require.NoError(t, chain.AddBlock(blocks[i])) } - require.NoError(t, chain.Persist(context.Background())) } func makeBlocks(n int) []*core.Block {