diff --git a/pkg/chain/chaindb.go b/pkg/chain/chaindb.go
new file mode 100644
index 000000000..9b16b366a
--- /dev/null
+++ b/pkg/chain/chaindb.go
@@ -0,0 +1,372 @@
+package chain
+
+import (
+	"bufio"
+	"bytes"
+	"encoding/binary"
+
+	"github.com/CityOfZion/neo-go/pkg/database"
+	"github.com/CityOfZion/neo-go/pkg/wire/payload"
+	"github.com/CityOfZion/neo-go/pkg/wire/payload/transaction"
+	"github.com/CityOfZion/neo-go/pkg/wire/util"
+)
+
+var (
+	// TX is the prefix used when inserting a tx into the db
+	TX = []byte("TX")
+	// HEADER is the prefix used when inserting a header into the db
+	HEADER = []byte("HE")
+	// LATESTHEADER is the prefix used when inserting the latest header into the db
+	LATESTHEADER = []byte("LH")
+	// UTXO is the prefix used when inserting a utxo into the db
+	UTXO = []byte("UT")
+	// LATESTBLOCK is the prefix used when inserting the latest block into the db
+	LATESTBLOCK = []byte("LB")
+	// BLOCKHASHTX is the prefix used when linking a blockhash to a given tx
+	BLOCKHASHTX = []byte("BT")
+	// BLOCKHASHHEIGHT is the prefix used when linking a blockhash to its height
+	// This is linked both ways
+	BLOCKHASHHEIGHT = []byte("BH")
+	// SCRIPTHASHUTXO is the prefix used when linking a utxo to a scripthash
+	// This is linked both ways
+	SCRIPTHASHUTXO = []byte("SU")
+)
+
+// Chaindb is a wrapper around the db interface which adds an extra blockchain-specific layer on top.
+type Chaindb struct {
+	db database.Database
+}
+
+// saveBlock should not be exported to other callers.
+// It is safe-guarded by the chain's verification logic.
+func (c *Chaindb) saveBlock(blk payload.Block, genesis bool) error {
+
+	latestBlockTable := database.NewTable(c.db, LATESTBLOCK)
+	hashHeightTable := database.NewTable(c.db, BLOCKHASHHEIGHT)
+
+	// Save txs and link them to the block hash
+	err := c.saveTXs(blk.Txs, blk.Hash.Bytes(), genesis)
+	if err != nil {
+		return err
+	}
+
+	// Link block height to hash - both ways.
+	// This allows us to fetch a block using either its hash or its height:
+	// given the height, we search this table to get the hash,
+	// and can then fetch all transactions in the tx table which match that block hash.
+	height := uint32ToBytes(blk.Index)
+	err = hashHeightTable.Put(height, blk.Hash.Bytes())
+	if err != nil {
+		return err
+	}
+
+	err = hashHeightTable.Put(blk.Hash.Bytes(), height)
+	if err != nil {
+		return err
+	}
+
+	// Add the block as the latest block.
+	// This also acts as a Commit() for the block:
+	// if an error occurred earlier, this still points to the previous block.
+	// This is useful because if the node suddenly shut down while saving and the database was not corrupted,
+	// the node will see the latest block as the last saved block and re-download the faulty block.
+	// Note: we check for the latest block on startup.
+	return latestBlockTable.Put([]byte(""), blk.Hash.Bytes())
+}
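The two-way height/hash link written by saveBlock is what makes height-based block lookups possible. A minimal sketch of how a caller could use it; GetBlockFromHeight is a hypothetical helper and not part of this patch, while GetBlockFromHash is defined further down in this file:

// GetBlockFromHeight (hypothetical) resolves a height to a hash via the
// BLOCKHASHHEIGHT table and then rebuilds the block from its hash.
func (c *Chaindb) GetBlockFromHeight(height uint32) (*payload.Block, error) {
	hashHeightTable := database.NewTable(c.db, BLOCKHASHHEIGHT)

	hash, err := hashHeightTable.Get(uint32ToBytes(height))
	if err != nil {
		return nil, err
	}
	return c.GetBlockFromHash(hash)
}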
+
+// saveTXs saves the given txs and links each tx to the block it was found in.
+// This should never be exported. The only way to add a tx is through its block.
+func (c *Chaindb) saveTXs(txs []transaction.Transactioner, blockHash []byte, genesis bool) error {
+
+	for txIndex, tx := range txs {
+		err := c.saveTx(tx, uint32(txIndex), blockHash, genesis)
+		if err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+func (c *Chaindb) saveTx(tx transaction.Transactioner, txIndex uint32, blockHash []byte, genesis bool) error {
+
+	txTable := database.NewTable(c.db, TX)
+	blockTxTable := database.NewTable(c.db, BLOCKHASHTX)
+
+	// Save the whole tx using its hash as key.
+	// In order to find a tx in this table, we need to know its hash.
+	txHash, err := tx.ID()
+	if err != nil {
+		return err
+	}
+	err = txTable.Put(txHash.Bytes(), tx.Bytes())
+	if err != nil {
+		return err
+	}
+
+	// Link the tx hash to its block.
+	// This allows us to fetch a tx by knowing only what block it was in,
+	// which is useful when we want to re-construct a block from its hash.
+	// In order to get the txs, we do a prefix search on blockHash.
+	// This will return a set of txHashes.
+	// We can then use these hashes to search the tx table for the txs we need.
+	key := bytesConcat(blockHash, uint32ToBytes(txIndex))
+	err = blockTxTable.Put(key, txHash.Bytes())
+	if err != nil {
+		return err
+	}
+
+	// Save all of the utxos in a transaction.
+	// We do this additional save so that we can form a utxo database
+	// and know when a transaction is a double spend.
+	utxos := tx.UTXOs()
+	for utxoIndex, utxo := range utxos {
+		err := c.saveUTXO(utxo, uint16(utxoIndex), txHash.Bytes(), blockHash)
+		if err != nil {
+			return err
+		}
+	}
+
+	// Do not check for spent utxos on the genesis block
+	if genesis {
+		return nil
+	}
+
+	// Remove all spent utxos.
+	// We do this so that once an output has been spent,
+	// it is removed from the utxo database and cannot be spent again.
+	// If the output was never in the utxo database, this function will return an error.
+	txos := tx.TXOs()
+	for _, txo := range txos {
+		err := c.removeUTXO(txo)
+		if err != nil {
+			return err
+		}
+	}
+	return nil
+}
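Because every unspent output sits in the UTXO table under txHash+outputIndex, a double spend shows up as a missing key. A sketch of such a check; the helper name isSpent is illustrative and not part of this patch:

// isSpent (hypothetical) reports whether an input no longer has an entry in
// the UTXO table, i.e. it was already spent or never existed.
func (c *Chaindb) isSpent(txo *transaction.Input) (bool, error) {
	utxoTable := database.NewTable(c.db, UTXO)

	key := bytesConcat(txo.PrevHash.Bytes(), uint16ToBytes(txo.PrevIndex))
	has, err := utxoTable.Has(key)
	if err != nil {
		return false, err
	}
	return !has, nil
}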
+
+// saveUTXO will save a utxo and link it to its transaction and block
+func (c *Chaindb) saveUTXO(utxo *transaction.Output, utxoIndex uint16, txHash, blockHash []byte) error {
+
+	utxoTable := database.NewTable(c.db, UTXO)
+	scripthashUTXOTable := database.NewTable(c.db, SCRIPTHASHUTXO)
+
+	// This is quite messy; we should (if possible) find a way to pass a Writer and Reader interface.
+	// Encode the utxo into a buffer.
+	buf := new(bytes.Buffer)
+	bw := &util.BinWriter{W: buf}
+	if utxo.Encode(bw); bw.Err != nil {
+		return bw.Err
+	}
+
+	// Save the utxo.
+	// In order to find a utxo in the utxoTable,
+	// one must know the txHash that the utxo was in.
+	key := bytesConcat(txHash, uint16ToBytes(utxoIndex))
+	if err := utxoTable.Put(key, buf.Bytes()); err != nil {
+		return err
+	}
+
+	// Link the utxo to its scripthash.
+	// This allows us to find a utxo given only the scriptHash.
+	// Since the key starts with scriptHash, we can look for the scriptHash prefix
+	// and find all utxos for a given scriptHash.
+	// Additionally, we can search for all utxos for a certain user in a certain block with scriptHash+blockHash,
+	// but this may not be of use to us. However, note that we cannot have just the scriptHash with the utxoIndex,
+	// as this may not be unique. If Kim/Dautt agree, we can change blockHash to blockHeight, which allows us
+	// to get all utxos above a certain blockHeight. The question is: would this be useful?
+	newKey := bytesConcat(utxo.ScriptHash.Bytes(), blockHash, uint16ToBytes(utxoIndex))
+	if err := scripthashUTXOTable.Put(newKey, key); err != nil {
+		return err
+	}
+	if err := scripthashUTXOTable.Put(key, newKey); err != nil {
+		return err
+	}
+	return nil
+}
+
+// removeUTXO removes a spent utxo from the utxo table and the scripthash table
+func (c *Chaindb) removeUTXO(txo *transaction.Input) error {
+
+	utxoTable := database.NewTable(c.db, UTXO)
+	scripthashUTXOTable := database.NewTable(c.db, SCRIPTHASHUTXO)
+
+	// Remove the spent utxo from the utxo database
+	key := bytesConcat(txo.PrevHash.Bytes(), uint16ToBytes(txo.PrevIndex))
+	err := utxoTable.Delete(key)
+	if err != nil {
+		return err
+	}
+
+	// Remove the utxo from the scripthash table
+	otherKey, err := scripthashUTXOTable.Get(key)
+	if err != nil {
+		return err
+	}
+	if err := scripthashUTXOTable.Delete(otherKey); err != nil {
+		return err
+	}
+	if err := scripthashUTXOTable.Delete(key); err != nil {
+		return err
+	}
+
+	return nil
+}
+
+// saveHeaders will save a set of headers into the database
+func (c *Chaindb) saveHeaders(headers []*payload.BlockBase) error {
+
+	for _, hdr := range headers {
+		err := c.saveHeader(hdr)
+		if err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+// saveHeader saves a header into the database and updates the latest header.
+// The headers are saved with their block heights as key.
+// If we want to search for a header, we need to know its index.
+// Alternatively, we can search the hashHeightTable with the block index to get the hash,
+// if the block has been saved.
+// The reason why headers are saved with their index as key is so that we can
+// increment the key to find out what block we should fetch next during the initial
+// block download, when we are saving thousands of headers.
+func (c *Chaindb) saveHeader(hdr *payload.BlockBase) error {
+
+	headerTable := database.NewTable(c.db, HEADER)
+	latestHeaderTable := database.NewTable(c.db, LATESTHEADER)
+
+	index := uint32ToBytes(hdr.Index)
+
+	byt, err := hdr.Bytes()
+	if err != nil {
+		return err
+	}
+
+	err = headerTable.Put(index, byt)
+	if err != nil {
+		return err
+	}
+
+	// Update the latest header
+	return latestHeaderTable.Put([]byte(""), index)
+}
+
+// GetHeaderFromHeight will get a header given its block height
+func (c *Chaindb) GetHeaderFromHeight(index []byte) (*payload.BlockBase, error) {
+	headerTable := database.NewTable(c.db, HEADER)
+	hdrBytes, err := headerTable.Get(index)
+	if err != nil {
+		return nil, err
+	}
+	reader := bytes.NewReader(hdrBytes)
+
+	blockBase := &payload.BlockBase{}
+	err = blockBase.Decode(reader)
+	if err != nil {
+		return nil, err
+	}
+	return blockBase, nil
+}
+
+// GetLastHeader will get the header which was saved last in the database
+func (c *Chaindb) GetLastHeader() (*payload.BlockBase, error) {
+
+	latestHeaderTable := database.NewTable(c.db, LATESTHEADER)
+	index, err := latestHeaderTable.Get([]byte(""))
+	if err != nil {
+		return nil, err
+	}
+	return c.GetHeaderFromHeight(index)
+}
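Since header keys are big-endian block heights, the next header to request during the initial block download is simply the last saved index plus one. A sketch of that lookup; nextHeaderIndex is illustrative and not part of this patch:

// nextHeaderIndex (hypothetical) returns the height key that the initial
// block download should request next, based on the header saved last.
func (c *Chaindb) nextHeaderIndex() ([]byte, error) {
	lastHdr, err := c.GetLastHeader()
	if err != nil {
		return nil, err
	}
	return uint32ToBytes(lastHdr.Index + 1), nil
}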
+
+// GetBlockFromHash will return a block given its hash
+func (c *Chaindb) GetBlockFromHash(blockHash []byte) (*payload.Block, error) {
+
+	blockTxTable := database.NewTable(c.db, BLOCKHASHTX)
+
+	// To get a block we need to fetch:
+	// the transactions (1)
+	// the header (2)
+
+	// Reconstruct the block by fetching its txs (1)
+	var txs []transaction.Transactioner
+
+	// Get all tx hashes for this block
+	txHashes, err := blockTxTable.Prefix(blockHash)
+	if err != nil {
+		return nil, err
+	}
+
+	// Get all txs given their hashes
+	txTable := database.NewTable(c.db, TX)
+	for _, txHash := range txHashes {
+
+		// Fetch the tx by its hash
+		txBytes, err := txTable.Get(txHash)
+		if err != nil {
+			return nil, err
+		}
+		reader := bufio.NewReader(bytes.NewReader(txBytes))
+
+		tx, err := transaction.FromReader(reader)
+		if err != nil {
+			return nil, err
+		}
+		txs = append(txs, tx)
+	}
+
+	// Now fetch the header (2).
+	// We have the block hash, but headers are stored with their `Height` as key.
+	// We first search the `BlockHashHeight` table to get the height,
+	// then we search the headers table with that height.
+	hashHeightTable := database.NewTable(c.db, BLOCKHASHHEIGHT)
+	height, err := hashHeightTable.Get(blockHash)
+	if err != nil {
+		return nil, err
+	}
+	hdr, err := c.GetHeaderFromHeight(height)
+	if err != nil {
+		return nil, err
+	}
+
+	// Construct the block
+	block := &payload.Block{
+		BlockBase: *hdr,
+		Txs:       txs,
+	}
+	return block, nil
+}
+
+// GetLastBlock will return the last block that has been saved
+func (c *Chaindb) GetLastBlock() (*payload.Block, error) {
+
+	latestBlockTable := database.NewTable(c.db, LATESTBLOCK)
+	blockHash, err := latestBlockTable.Get([]byte(""))
+	if err != nil {
+		return nil, err
+	}
+	return c.GetBlockFromHash(blockHash)
+}
+
+func uint16ToBytes(x uint16) []byte {
+	index := make([]byte, 2)
+	binary.BigEndian.PutUint16(index, x)
+	return index
+}
+
+func uint32ToBytes(x uint32) []byte {
+	index := make([]byte, 4)
+	binary.BigEndian.PutUint32(index, x)
+	return index
+}
+
+func bytesConcat(args ...[]byte) []byte {
+	var res []byte
+	for _, arg := range args {
+		res = append(res, arg...)
+	}
+	return res
+}
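Taken together, the round trip that the tests below exercise looks roughly like this. A sketch, not part of this patch, assuming blk is a payload.Block that has already passed the chain's verification logic:

// saveAndReload is a sketch of the save/read-back flow, not part of this patch.
func saveAndReload(blk payload.Block) (*payload.Block, error) {
	db, err := database.New("chain") // files are created under database.DbDir
	if err != nil {
		return nil, err
	}
	cdb := &Chaindb{db}

	// Header first (keyed by height), then the block (txs, utxos, height<->hash links).
	if err := cdb.saveHeader(&blk.BlockBase); err != nil {
		return nil, err
	}
	if err := cdb.saveBlock(blk, false); err != nil {
		return nil, err
	}

	// GetLastBlock follows the LATESTBLOCK pointer that saveBlock wrote last.
	return cdb.GetLastBlock()
}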
diff --git a/pkg/chain/chaindb_test.go b/pkg/chain/chaindb_test.go
new file mode 100644
index 000000000..b5656d30b
--- /dev/null
+++ b/pkg/chain/chaindb_test.go
@@ -0,0 +1,201 @@
+package chain
+
+import (
+	"bytes"
+	"math/rand"
+	"os"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+
+	"github.com/CityOfZion/neo-go/pkg/database"
+	"github.com/CityOfZion/neo-go/pkg/wire/payload"
+	"github.com/CityOfZion/neo-go/pkg/wire/payload/transaction"
+	"github.com/CityOfZion/neo-go/pkg/wire/util"
+)
+
+var s = rand.NewSource(time.Now().UnixNano())
+var r = rand.New(s)
+
+func TestLastHeader(t *testing.T) {
+	_, cdb, hdrs := saveRandomHeaders(t)
+
+	// Select the last header from the list of headers
+	lastHeader := hdrs[len(hdrs)-1]
+	// GetLastHeader from the database
+	hdr, err := cdb.GetLastHeader()
+	assert.Nil(t, err)
+	assert.Equal(t, hdr.Index, lastHeader.Index)
+
+	// Clean up
+	os.RemoveAll(database.DbDir)
+}
+
+func TestSaveHeader(t *testing.T) {
+	// Save headers, then check each element
+
+	db, _, hdrs := saveRandomHeaders(t)
+
+	headerTable := database.NewTable(db, HEADER)
+	// Check that each header was saved
+	for _, hdr := range hdrs {
+		index := uint32ToBytes(hdr.Index)
+		ok, err := headerTable.Has(index)
+		assert.Nil(t, err)
+		assert.True(t, ok)
+	}
+
+	// Clean up
+	os.RemoveAll(database.DbDir)
+}
+
+func TestSaveBlock(t *testing.T) {
+
+	// Init database
+	db, err := database.New("temp.test")
+	assert.Nil(t, err)
+
+	cdb := &Chaindb{db}
+
+	// Construct block0 and block1
+	block0, block1 := twoBlocksLinked(t)
+
+	// Save genesis header
+	err = cdb.saveHeader(&block0.BlockBase)
+	assert.Nil(t, err)
+
+	// Save genesis block
+	err = cdb.saveBlock(block0, true)
+	assert.Nil(t, err)
+
+	// Test genesis block was saved
+	testBlockWasSaved(t, cdb, block0)
+
+	// Save block1 header
+	err = cdb.saveHeader(&block1.BlockBase)
+	assert.Nil(t, err)
+
+	// Save block1
+	err = cdb.saveBlock(block1, false)
+	assert.Nil(t, err)
+
+	// Test block1 was saved
+	testBlockWasSaved(t, cdb, block1)
+
+	// Clean up
+	os.RemoveAll(database.DbDir)
+}
+
+func testBlockWasSaved(t *testing.T, cdb *Chaindb, block payload.Block) {
+	// Fetch the last block from the database
+	lastBlock, err := cdb.GetLastBlock()
+	assert.Nil(t, err)
+
+	// Get the byte representation of the last block from the database
+	byts, err := lastBlock.Bytes()
+	assert.Nil(t, err)
+
+	// Get the byte representation of the block that we saved
+	blockBytes, err := block.Bytes()
+	assert.Nil(t, err)
+
+	// They should be equal
+	assert.True(t, bytes.Equal(byts, blockBytes))
+}
+
+func randomHeaders(t *testing.T) []*payload.BlockBase {
+	assert := assert.New(t)
+	hdrsMsg, err := payload.NewHeadersMessage()
+	assert.Nil(err)
+
+	for i := 0; i < 2000; i++ {
+		err = hdrsMsg.AddHeader(randomBlockBase(t))
+		assert.Nil(err)
+	}
+
+	return hdrsMsg.Headers
+}
+
+func randomBlockBase(t *testing.T) *payload.BlockBase {
+
+	base := &payload.BlockBase{
+		Version:       r.Uint32(),
+		PrevHash:      randUint256(t),
+		MerkleRoot:    randUint256(t),
+		Timestamp:     r.Uint32(),
+		Index:         r.Uint32(),
+		ConsensusData: r.Uint64(),
+		NextConsensus: randUint160(t),
+		Witness: transaction.Witness{
+			InvocationScript:   []byte{0, 1, 2, 34, 56},
+			VerificationScript: []byte{0, 12, 3, 45, 66},
+		},
+		Hash: randUint256(t),
+	}
+	return base
+}
+
+func randomTxs(t *testing.T) []transaction.Transactioner {
+
+	var txs []transaction.Transactioner
+	for i := 0; i < 10; i++ {
+		tx := transaction.NewContract(0)
+		tx.AddInput(transaction.NewInput(randUint256(t), uint16(r.Int())))
+		tx.AddOutput(transaction.NewOutput(randUint256(t), r.Int63(), randUint160(t)))
+		txs = append(txs, tx)
+	}
+	return txs
+}
+
+func saveRandomHeaders(t *testing.T) (database.Database, *Chaindb, []*payload.BlockBase) {
+	db, err := database.New("temp.test")
+	assert.Nil(t, err)
+
+	cdb := &Chaindb{db}
+
+	hdrs := randomHeaders(t)
+
+	err = cdb.saveHeaders(hdrs)
+	assert.Nil(t, err)
+	return db, cdb, hdrs
+}
+
+func randUint256(t *testing.T) util.Uint256 {
+	slice := make([]byte, 32)
+	_, err := r.Read(slice)
+	u, err := util.Uint256DecodeBytes(slice)
+	assert.Nil(t, err)
+	return u
+}
+func randUint160(t *testing.T) util.Uint160 {
+	slice := make([]byte, 20)
+	_, err := r.Read(slice)
+	u, err := util.Uint160DecodeBytes(slice)
+	assert.Nil(t, err)
+	return u
+}
+
+// twoBlocksLinked will return two blocks; the second block spends from the utxos in the first
+func twoBlocksLinked(t *testing.T) (payload.Block, payload.Block) {
+	genesisBase := randomBlockBase(t)
+	genesisTxs := randomTxs(t)
+	genesisBlock := payload.Block{BlockBase: *genesisBase, Txs: genesisTxs}
+
+	var txs []transaction.Transactioner
+
+	// Form transactions that spend from the genesis block
+	for _, tx := range genesisTxs {
+		txHash, err := tx.ID()
+		assert.Nil(t, err)
+		newTx := transaction.NewContract(0)
+		newTx.AddInput(transaction.NewInput(txHash, 0))
+		newTx.AddOutput(transaction.NewOutput(randUint256(t), r.Int63(), randUint160(t)))
+		txs = append(txs, newTx)
+	}
+
+	nextBase := randomBlockBase(t)
+	nextBlock := payload.Block{BlockBase: *nextBase, Txs: txs}
+
+	return genesisBlock, nextBlock
+}
diff --git a/pkg/database/leveldb.go b/pkg/database/leveldb.go
index b90ffaf65..a039c6010 100644
--- a/pkg/database/leveldb.go
+++ b/pkg/database/leveldb.go
@@ -3,14 +3,22 @@ package database
 import (
 	"github.com/syndtr/goleveldb/leveldb"
 	"github.com/syndtr/goleveldb/leveldb/errors"
+	ldbutil "github.com/syndtr/goleveldb/leveldb/util"
 )
 
+// DbDir is the folder under which all database files will be put
+// Structure: /DbDir/net
+const DbDir = "db/"
+
 // LDB represents a leveldb object
 type LDB struct {
 	db   *leveldb.DB
-	path string
+	Path string
 }
 
+// ErrNotFound means that the value was not found in the db
+var ErrNotFound = errors.New("value not found for that key")
+
 // Database contains all methods needed for an object to be a database
 type Database interface {
 	// Has checks whether the key is in the database
@@ -21,25 +29,30 @@ type Database interface {
 	Get(key []byte) ([]byte, error)
 	// Delete deletes the given value for the key from the database
 	Delete(key []byte) error
+	// Prefix returns all values stored under keys that start with the given key
+	Prefix(key []byte) ([][]byte, error)
 	// Close closes the underlying db object
 	Close() error
 }
 
 // New will return a new leveldb instance
-func New(path string) *LDB {
-	db, err := leveldb.OpenFile(path, nil)
+func New(path string) (*LDB, error) {
+	dbPath := DbDir + path
+	db, err := leveldb.OpenFile(dbPath, nil)
 	if err != nil {
-		return nil
+		return nil, err
 	}
-
 	if _, corrupted := err.(*errors.ErrCorrupted); corrupted {
 		db, err = leveldb.RecoverFile(path, nil)
+		if err != nil {
+			return nil, err
+		}
 	}
 
 	return &LDB{
 		db,
-		path,
-	}
+		dbPath,
+	}, nil
 }
 
 // Has implements the database interface
@@ -54,7 +67,15 @@ func (l *LDB) Put(key []byte, value []byte) error {
 
 // Get implements the database interface
 func (l *LDB) Get(key []byte) ([]byte, error) {
-	return l.db.Get(key, nil)
+	val, err := l.db.Get(key, nil)
+	if err == nil {
+		return val, nil
+	}
+	if err == leveldb.ErrNotFound {
+		return val, ErrNotFound
+	}
+	return val, err
+
 }
 
 // Delete implements the database interface
@@ -66,3 +87,28 @@ func (l *LDB) Delete(key []byte) error {
 func (l *LDB) Close() error {
 	return l.db.Close()
 }
+
+// Prefix implements the database interface
+func (l *LDB) Prefix(key []byte) ([][]byte, error) {
+
+	var results [][]byte
+
+	iter := l.db.NewIterator(ldbutil.BytesPrefix(key), nil)
+	for iter.Next() {
+
+		value := iter.Value()
+
+		// Copy the data, as we cannot use it
+		// once the iterator has been released
+		deref := make([]byte, len(value))
+
+		copy(deref, value)
+
+		// Append result
+		results = append(results, deref)
+
+	}
+	iter.Release()
+	err := iter.Error()
+	return results, err
+}
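The Prefix method and the ErrNotFound sentinel introduced here are what the Chaindb tables rely on. A usage sketch from a caller's perspective (as in the tests below); the database name and keys are illustrative and the snippet is not part of this patch:

func prefixExample() error {
	db, err := database.New("example") // files end up under db/example
	if err != nil {
		return err
	}
	defer db.Close()

	_ = db.Put([]byte("TXaaa"), []byte("one"))
	_ = db.Put([]byte("TXbbb"), []byte("two"))

	// Prefix copies each value out of the iterator before releasing it.
	vals, err := db.Prefix([]byte("TX")) // -> "one", "two"
	if err != nil {
		return err
	}
	_ = vals

	// A missing key now surfaces as the package-level sentinel,
	// so callers no longer need to import goleveldb's errors package.
	if _, err := db.Get([]byte("missing")); err == database.ErrNotFound {
		// handle "not found"
	}
	return nil
}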
diff --git a/pkg/database/leveldb_test.go b/pkg/database/leveldb_test.go
index 0991682ee..61c831bd3 100644
--- a/pkg/database/leveldb_test.go
+++ b/pkg/database/leveldb_test.go
@@ -6,27 +6,31 @@ import (
 
 	"github.com/CityOfZion/neo-go/pkg/database"
 	"github.com/stretchr/testify/assert"
-	"github.com/syndtr/goleveldb/leveldb/errors"
 )
 
 const path = "temp"
 
 func cleanup(db *database.LDB) {
 	db.Close()
-	os.RemoveAll(path)
+	os.RemoveAll(database.DbDir)
 }
 
 func TestDBCreate(t *testing.T) {
-	db := database.New(path)
+
+	db, err := database.New(path)
+	assert.Nil(t, err)
+
 	assert.NotEqual(t, nil, db)
 	cleanup(db)
 }
 
 func TestPutGet(t *testing.T) {
-	db := database.New(path)
+
+	db, err := database.New(path)
+	assert.Nil(t, err)
 
 	key := []byte("Hello")
 	value := []byte("World")
-	err := db.Put(key, value)
+	err = db.Put(key, value)
 	assert.Equal(t, nil, err)
 
 	res, err := db.Get(key)
@@ -36,25 +40,28 @@ func TestPutGet(t *testing.T) {
 }
 
 func TestPutDelete(t *testing.T) {
-	db := database.New(path)
+	db, err := database.New(path)
+	assert.Nil(t, err)
 
 	key := []byte("Hello")
 	value := []byte("World")
 
-	err := db.Put(key, value)
+	err = db.Put(key, value)
 
 	err = db.Delete(key)
 	assert.Equal(t, nil, err)
 
 	res, err := db.Get(key)
-	assert.Equal(t, errors.ErrNotFound, err)
+	assert.Equal(t, database.ErrNotFound, err)
 	assert.Equal(t, res, []byte{})
 
 	cleanup(db)
 }
 
 func TestHas(t *testing.T) {
-	db := database.New("temp")
+
+	db, err := database.New(path)
+	assert.Nil(t, err)
 
 	res, err := db.Has([]byte("NotExist"))
 	assert.Equal(t, res, false)
@@ -73,8 +80,12 @@ func TestHas(t *testing.T) {
 }
 
 func TestDBClose(t *testing.T) {
-	db := database.New("temp")
-	err := db.Close()
+
+	db, err := database.New(path)
+	assert.Nil(t, err)
+
+	err = db.Close()
 	assert.Equal(t, nil, err)
+
 	cleanup(db)
 }
diff --git a/pkg/database/table.go b/pkg/database/table.go
index c36b80185..8c4cf3023 100644
--- a/pkg/database/table.go
+++ b/pkg/database/table.go
@@ -16,29 +16,35 @@ func NewTable(db Database, prefix []byte) *Table {
 
 // Has implements the database interface
 func (t *Table) Has(key []byte) (bool, error) {
-	key = append(t.prefix, key...)
-	return t.db.Has(key)
+	prefixedKey := append(t.prefix, key...)
+	return t.db.Has(prefixedKey)
 }
 
 // Put implements the database interface
 func (t *Table) Put(key []byte, value []byte) error {
-	key = append(t.prefix, key...)
-	return t.db.Put(key, value)
+	prefixedKey := append(t.prefix, key...)
+	return t.db.Put(prefixedKey, value)
 }
 
 // Get implements the database interface
 func (t *Table) Get(key []byte) ([]byte, error) {
-	key = append(t.prefix, key...)
-	return t.db.Get(key)
+	prefixedKey := append(t.prefix, key...)
+	return t.db.Get(prefixedKey)
 }
 
 // Delete implements the database interface
 func (t *Table) Delete(key []byte) error {
-	key = append(t.prefix, key...)
-	return t.db.Delete(key)
+	prefixedKey := append(t.prefix, key...)
+	return t.db.Delete(prefixedKey)
 }
 
 // Close implements the database interface
 func (t *Table) Close() error {
 	return nil
 }
+
+// Prefix implements the database interface
+func (t *Table) Prefix(key []byte) ([][]byte, error) {
+	prefixedKey := append(t.prefix, key...)
+	return t.db.Prefix(prefixedKey)
+}
diff --git a/pkg/peer/peer.go b/pkg/peer/peer.go
index 6a87a4aa6..035dad835 100644
--- a/pkg/peer/peer.go
+++ b/pkg/peer/peer.go
@@ -195,13 +195,15 @@ func (p *Peer) PingLoop() { /*not implemented in other neo clients*/ }
 
 func (p *Peer) Run() error {
 
 	err := p.Handshake()
-
+	if err != nil {
+		return err
+	}
 	go p.StartProtocol()
 	go p.ReadLoop()
 	go p.WriteLoop()
 
 	//go p.PingLoop() // since it is not implemented. It will disconnect all other impls.
 
-	return err
+	return nil
 }
diff --git a/pkg/wire/protocol/protocol.go b/pkg/wire/protocol/protocol.go
index ebc916e50..d8b223267 100644
--- a/pkg/wire/protocol/protocol.go
+++ b/pkg/wire/protocol/protocol.go
@@ -30,3 +30,15 @@ const (
 	MainNet Magic = 7630401
 	TestNet Magic = 0x74746e41
 )
+
+// String implements the stringer interface
+func (m Magic) String() string {
+	switch m {
+	case MainNet:
+		return "Mainnet"
+	case TestNet:
+		return "Testnet"
+	default:
+		return "UnknownNet"
+	}
+}
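With these changes, Run reports a failed handshake to the caller instead of silently starting the read and write loops, and Magic.String makes the network readable in logs. A small usage sketch, not part of this patch; construction of the *peer.Peer is elided:

func startPeer(p *peer.Peer, net protocol.Magic) error {
	// Run now returns the handshake error before any goroutines are started.
	if err := p.Run(); err != nil {
		return fmt.Errorf("handshake failed on %s: %v", net, err)
	}
	// At this point the protocol, read and write loops are running in goroutines.
	return nil
}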