Merge branch 'dev' into vm

This commit is contained in:
Roman Khimov 2019-08-12 19:12:19 +03:00 committed by GitHub
commit 2ddf1d15ab
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
72 changed files with 3560 additions and 731 deletions

View file

@ -9,7 +9,7 @@ install:
- go mod tidy -v - go mod tidy -v
script: script:
- golint -set_exit_status ./... - golint -set_exit_status ./...
- go test -race -coverprofile=coverage.txt -covermode=atomic ./... - go test -v -race -coverprofile=coverage.txt -covermode=atomic ./...
after_success: after_success:
- bash <(curl -s https://codecov.io/bash) - bash <(curl -s https://codecov.io/bash)
matrix: matrix:

20
main.go Normal file
View file

@ -0,0 +1,20 @@
package main
import (
"fmt"
"github.com/CityOfZion/neo-go/pkg/server"
"github.com/CityOfZion/neo-go/pkg/wire/protocol"
)
func main() {
s, err := server.New(protocol.MainNet, 10332)
if err != nil {
fmt.Println(err)
return
}
err = s.Run()
if err != nil {
fmt.Println("Server has stopped from the following error: ", err.Error())
}
}

137
pkg/chain/chain.go Normal file
View file

@ -0,0 +1,137 @@
package chain
import (
"fmt"
"github.com/pkg/errors"
"github.com/CityOfZion/neo-go/pkg/chaincfg"
"github.com/CityOfZion/neo-go/pkg/wire/payload/transaction"
"github.com/CityOfZion/neo-go/pkg/wire/protocol"
"github.com/CityOfZion/neo-go/pkg/database"
"github.com/CityOfZion/neo-go/pkg/wire/payload"
)
var (
// ErrBlockAlreadyExists happens when you try to save the same block twice
ErrBlockAlreadyExists = errors.New("this block has already been saved in the database")
// ErrFutureBlock happens when you try to save a block that is not the next block sequentially
ErrFutureBlock = errors.New("this is not the next block sequentially, that should be added to the chain")
)
// Chain represents a blockchain instance
type Chain struct {
Db *Chaindb
height uint32
}
// New returns a new chain instance
func New(db database.Database, magic protocol.Magic) (*Chain, error) {
chain := &Chain{
Db: &Chaindb{db},
}
// Get last header saved to see if this is a fresh database
_, err := chain.Db.GetLastHeader()
if err == nil {
return chain, nil
}
if err != database.ErrNotFound {
return nil, err
}
// We have a database.ErrNotFound. Insert the genesisBlock
fmt.Printf("Starting a fresh database for %s\n", magic.String())
params, err := chaincfg.NetParams(magic)
if err != nil {
return nil, err
}
err = chain.Db.saveHeader(&params.GenesisBlock.BlockBase)
if err != nil {
return nil, err
}
err = chain.Db.saveBlock(params.GenesisBlock, true)
if err != nil {
return nil, err
}
return chain, nil
}
// ProcessBlock verifies and saves the block in the database
// XXX: for now we will just save without verifying the block
// This function is called by the server and if an error is returned then
// the server informs the sync manager to redownload the block
// XXX:We should also check if the header is already saved in the database
// If not, then we need to validate the header with the rest of the chain
// For now we re-save the header
func (c *Chain) ProcessBlock(block payload.Block) error {
// Check if we already have this block saved
// XXX: We can optimise by implementing a Has() method
// caching the last block in memory
lastBlock, err := c.Db.GetLastBlock()
if err != nil {
return err
}
if lastBlock.Index > block.Index {
return ErrBlockAlreadyExists
}
if block.Index > lastBlock.Index+1 {
return ErrFutureBlock
}
err = c.verifyBlock(block)
if err != nil {
return ValidationError{err.Error()}
}
err = c.Db.saveBlock(block, false)
if err != nil {
return DatabaseError{err.Error()}
}
return nil
}
// VerifyBlock verifies whether a block is valid according
// to the rules of consensus
func (c *Chain) verifyBlock(block payload.Block) error {
return nil
}
// VerifyTx verifies whether a transaction is valid according
// to the rules of consensus
func (c *Chain) VerifyTx(tx transaction.Transactioner) error {
return nil
}
// ProcessHeaders will save the set of headers without validating
func (c *Chain) ProcessHeaders(hdrs []*payload.BlockBase) error {
err := c.verifyHeaders(hdrs)
if err != nil {
return ValidationError{err.Error()}
}
err = c.Db.saveHeaders(hdrs)
if err != nil {
return DatabaseError{err.Error()}
}
return nil
}
// verifyHeaders will be used to verify a batch of headers
// should only ever be called during the initial block download
// or when the node receives a HeadersMessage
func (c *Chain) verifyHeaders(hdrs []*payload.BlockBase) error {
return nil
}
// CurrentHeight returns the index of the block
// at the tip of the chain
func (c Chain) CurrentHeight() uint32 {
return c.height
}

372
pkg/chain/chaindb.go Normal file
View file

@ -0,0 +1,372 @@
package chain
import (
"bufio"
"bytes"
"encoding/binary"
"github.com/CityOfZion/neo-go/pkg/database"
"github.com/CityOfZion/neo-go/pkg/wire/payload"
"github.com/CityOfZion/neo-go/pkg/wire/payload/transaction"
"github.com/CityOfZion/neo-go/pkg/wire/util"
)
var (
// TX is the prefix used when inserting a tx into the db
TX = []byte("TX")
// HEADER is the prefix used when inserting a header into the db
HEADER = []byte("HE")
// LATESTHEADER is the prefix used when inserting the latests header into the db
LATESTHEADER = []byte("LH")
// UTXO is the prefix used when inserting a utxo into the db
UTXO = []byte("UT")
// LATESTBLOCK is the prefix used when inserting the latest block into the db
LATESTBLOCK = []byte("LB")
// BLOCKHASHTX is the prefix used when linking a blockhash to a given tx
BLOCKHASHTX = []byte("BT")
// BLOCKHASHHEIGHT is the prefix used when linking a blockhash to it's height
// This is linked both ways
BLOCKHASHHEIGHT = []byte("BH")
// SCRIPTHASHUTXO is the prefix used when linking a utxo to a scripthash
// This is linked both ways
SCRIPTHASHUTXO = []byte("SU")
)
// Chaindb is a wrapper around the db interface which adds an extra block chain specific layer on top.
type Chaindb struct {
db database.Database
}
// This should not be exported for other callers.
// It is safe-guarded by the chain's verification logic
func (c *Chaindb) saveBlock(blk payload.Block, genesis bool) error {
latestBlockTable := database.NewTable(c.db, LATESTBLOCK)
hashHeightTable := database.NewTable(c.db, BLOCKHASHHEIGHT)
// Save Txs and link to block hash
err := c.saveTXs(blk.Txs, blk.Hash.Bytes(), genesis)
if err != nil {
return err
}
// LINK block height to hash - Both ways
// This allows us to fetch a block using it's hash or it's height
// Given the height, we will search the table to get the hash
// We can then fetch all transactions in the tx table, which match that block hash
height := uint32ToBytes(blk.Index)
err = hashHeightTable.Put(height, blk.Hash.Bytes())
if err != nil {
return err
}
err = hashHeightTable.Put(blk.Hash.Bytes(), height)
if err != nil {
return err
}
// Add block as latest block
// This also acts a Commit() for the block.
// If an error occured, then this will be set to the previous block
// This is useful because if the node suddently shut down while saving and the database was not corrupted
// Then the node will see the latestBlock as the last saved block, and re-download the faulty block
// Note: We check for the latest block on startup
return latestBlockTable.Put([]byte(""), blk.Hash.Bytes())
}
// Saves a tx and links each tx to the block it was found in
// This should never be exported. Only way to add a tx, is through it's block
func (c *Chaindb) saveTXs(txs []transaction.Transactioner, blockHash []byte, genesis bool) error {
for txIndex, tx := range txs {
err := c.saveTx(tx, uint32(txIndex), blockHash, genesis)
if err != nil {
return err
}
}
return nil
}
func (c *Chaindb) saveTx(tx transaction.Transactioner, txIndex uint32, blockHash []byte, genesis bool) error {
txTable := database.NewTable(c.db, TX)
blockTxTable := database.NewTable(c.db, BLOCKHASHTX)
// Save the whole tx using it's hash a key
// In order to find a tx in this table, we need to know it's hash
txHash, err := tx.ID()
if err != nil {
return err
}
err = txTable.Put(txHash.Bytes(), tx.BaseTx().Bytes())
if err != nil {
return err
}
// LINK TXhash to block
// This allows us to fetch a tx by just knowing what block it was in
// This is useful for when we want to re-construct a block from it's hash
// In order to ge the tx, we must do a prefix search on blockHash
// This will return a set of txHashes.
//We can then use these hashes to search the txtable for the tx's we need
key := bytesConcat(blockHash, uint32ToBytes(txIndex))
err = blockTxTable.Put(key, txHash.Bytes())
if err != nil {
return err
}
// Save all of the utxos in a transaction
// We do this additional save so that we can form a utxo database
// and know when a transaction is a double spend.
utxos := tx.BaseTx().Outputs
for utxoIndex, utxo := range utxos {
err := c.saveUTXO(utxo, uint16(utxoIndex), txHash.Bytes(), blockHash)
if err != nil {
return err
}
}
// Do not check for spent utxos on the genesis block
if genesis {
return nil
}
// Remove all spent utxos
// We do this so that once an output has been spent
// It will be removed from the utxo database and cannot be spent again
// If the output was never in the utxo database, this function will return an error
txos := tx.BaseTx().Inputs
for _, txo := range txos {
err := c.removeUTXO(txo)
if err != nil {
return err
}
}
return nil
}
// saveUTxo will save a utxo and link it to it's transaction and block
func (c *Chaindb) saveUTXO(utxo *transaction.Output, utxoIndex uint16, txHash, blockHash []byte) error {
utxoTable := database.NewTable(c.db, UTXO)
scripthashUTXOTable := database.NewTable(c.db, SCRIPTHASHUTXO)
// This is quite messy, we should (if possible) find a way to pass a Writer and Reader interface
// Encode utxo into a buffer
buf := new(bytes.Buffer)
bw := &util.BinWriter{W: buf}
if utxo.Encode(bw); bw.Err != nil {
return bw.Err
}
// Save UTXO
// In order to find a utxo in the utxoTable
// One must know the txHash that the utxo was in
key := bytesConcat(txHash, uint16ToBytes(utxoIndex))
if err := utxoTable.Put(key, buf.Bytes()); err != nil {
return err
}
// LINK utxo to scripthash
// This allows us to find a utxo with the scriptHash
// Since the key starts with scriptHash, we can look for the scriptHash prefix
// and find all utxos for a given scriptHash.
// Additionally, we can search for all utxos for a certain user in a certain block with scriptHash+blockHash
// But this may not be of use to us. However, note that we cannot have just the scriptHash with the utxoIndex
// as this may not be unique. If Kim/Dautt agree, we can change blockHash to blockHeight, which allows us
// To get all utxos above a certain blockHeight. Question is; Would this be useful?
newKey := bytesConcat(utxo.ScriptHash.Bytes(), blockHash, uint16ToBytes(utxoIndex))
if err := scripthashUTXOTable.Put(newKey, key); err != nil {
return err
}
if err := scripthashUTXOTable.Put(key, newKey); err != nil {
return err
}
return nil
}
// Remove
func (c *Chaindb) removeUTXO(txo *transaction.Input) error {
utxoTable := database.NewTable(c.db, UTXO)
scripthashUTXOTable := database.NewTable(c.db, SCRIPTHASHUTXO)
// Remove spent utxos from utxo database
key := bytesConcat(txo.PrevHash.Bytes(), uint16ToBytes(txo.PrevIndex))
err := utxoTable.Delete(key)
if err != nil {
return err
}
// Remove utxos from scripthash table
otherKey, err := scripthashUTXOTable.Get(key)
if err != nil {
return err
}
if err := scripthashUTXOTable.Delete(otherKey); err != nil {
return err
}
if err := scripthashUTXOTable.Delete(key); err != nil {
return err
}
return nil
}
// saveHeaders will save a set of headers into the database
func (c *Chaindb) saveHeaders(headers []*payload.BlockBase) error {
for _, hdr := range headers {
err := c.saveHeader(hdr)
if err != nil {
return err
}
}
return nil
}
// saveHeader saves a header into the database and updates the latest header
// The headers are saved with their `blockheights` as Key
// If we want to search for a header, we need to know it's index
// Alternatively, we can search the hashHeightTable with the block index to get the hash
// If the block has been saved.
// The reason why headers are saved with their index as Key, is so that we can
// increment the key to find out what block we should fetch next during the initial
// block download, when we are saving thousands of headers
func (c *Chaindb) saveHeader(hdr *payload.BlockBase) error {
headerTable := database.NewTable(c.db, HEADER)
latestHeaderTable := database.NewTable(c.db, LATESTHEADER)
index := uint32ToBytes(hdr.Index)
byt, err := hdr.Bytes()
if err != nil {
return err
}
err = headerTable.Put(index, byt)
if err != nil {
return err
}
// Update latest header
return latestHeaderTable.Put([]byte(""), index)
}
// GetHeaderFromHeight will get a header given it's block height
func (c *Chaindb) GetHeaderFromHeight(index []byte) (*payload.BlockBase, error) {
headerTable := database.NewTable(c.db, HEADER)
hdrBytes, err := headerTable.Get(index)
if err != nil {
return nil, err
}
reader := bytes.NewReader(hdrBytes)
blockBase := &payload.BlockBase{}
err = blockBase.Decode(reader)
if err != nil {
return nil, err
}
return blockBase, nil
}
// GetLastHeader will get the header which was saved last in the database
func (c *Chaindb) GetLastHeader() (*payload.BlockBase, error) {
latestHeaderTable := database.NewTable(c.db, LATESTHEADER)
index, err := latestHeaderTable.Get([]byte(""))
if err != nil {
return nil, err
}
return c.GetHeaderFromHeight(index)
}
// GetBlockFromHash will return a block given it's hash
func (c *Chaindb) GetBlockFromHash(blockHash []byte) (*payload.Block, error) {
blockTxTable := database.NewTable(c.db, BLOCKHASHTX)
// To get a block we need to fetch:
// The transactions (1)
// The header (2)
// Reconstruct block by fetching it's txs (1)
var txs []transaction.Transactioner
// Get all Txhashes for this block
txHashes, err := blockTxTable.Prefix(blockHash)
if err != nil {
return nil, err
}
// Get all Tx's given their hash
txTable := database.NewTable(c.db, TX)
for _, txHash := range txHashes {
// Fetch tx by it's hash
txBytes, err := txTable.Get(txHash)
if err != nil {
return nil, err
}
reader := bufio.NewReader(bytes.NewReader(txBytes))
tx, err := transaction.FromReader(reader)
if err != nil {
return nil, err
}
txs = append(txs, tx)
}
// Now fetch the header (2)
// We have the block hash, but headers are stored with their `Height` as key.
// We first search the `BlockHashHeight` table to get the height.
//Then we search the headers table with the height
hashHeightTable := database.NewTable(c.db, BLOCKHASHHEIGHT)
height, err := hashHeightTable.Get(blockHash)
if err != nil {
return nil, err
}
hdr, err := c.GetHeaderFromHeight(height)
if err != nil {
return nil, err
}
// Construct block
block := &payload.Block{
BlockBase: *hdr,
Txs: txs,
}
return block, nil
}
// GetLastBlock will return the last block that has been saved
func (c *Chaindb) GetLastBlock() (*payload.Block, error) {
latestBlockTable := database.NewTable(c.db, LATESTBLOCK)
blockHash, err := latestBlockTable.Get([]byte(""))
if err != nil {
return nil, err
}
return c.GetBlockFromHash(blockHash)
}
func uint16ToBytes(x uint16) []byte {
index := make([]byte, 2)
binary.BigEndian.PutUint16(index, x)
return index
}
func uint32ToBytes(x uint32) []byte {
index := make([]byte, 4)
binary.BigEndian.PutUint32(index, x)
return index
}
func bytesConcat(args ...[]byte) []byte {
var res []byte
for _, arg := range args {
res = append(res, arg...)
}
return res
}

201
pkg/chain/chaindb_test.go Normal file
View file

@ -0,0 +1,201 @@
package chain
import (
"bytes"
"math/rand"
"os"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/CityOfZion/neo-go/pkg/database"
"github.com/CityOfZion/neo-go/pkg/wire/payload"
"github.com/CityOfZion/neo-go/pkg/wire/payload/transaction"
"github.com/CityOfZion/neo-go/pkg/wire/util"
)
var s = rand.NewSource(time.Now().UnixNano())
var r = rand.New(s)
func TestLastHeader(t *testing.T) {
_, cdb, hdrs := saveRandomHeaders(t)
// Select last header from list of headers
lastHeader := hdrs[len(hdrs)-1]
// GetLastHeader from the database
hdr, err := cdb.GetLastHeader()
assert.Nil(t, err)
assert.Equal(t, hdr.Index, lastHeader.Index)
// Clean up
os.RemoveAll(database.DbDir)
}
func TestSaveHeader(t *testing.T) {
// save headers then fetch a random element
db, _, hdrs := saveRandomHeaders(t)
headerTable := database.NewTable(db, HEADER)
// check that each header was saved
for _, hdr := range hdrs {
index := uint32ToBytes(hdr.Index)
ok, err := headerTable.Has(index)
assert.Nil(t, err)
assert.True(t, ok)
}
// Clean up
os.RemoveAll(database.DbDir)
}
func TestSaveBlock(t *testing.T) {
// Init databases
db, err := database.New("temp.test")
assert.Nil(t, err)
cdb := &Chaindb{db}
// Construct block0 and block1
block0, block1 := twoBlocksLinked(t)
// Save genesis header
err = cdb.saveHeader(&block0.BlockBase)
assert.Nil(t, err)
// Save genesis block
err = cdb.saveBlock(block0, true)
assert.Nil(t, err)
// Test genesis block saved
testBlockWasSaved(t, cdb, block0)
// Save block1 header
err = cdb.saveHeader(&block1.BlockBase)
assert.Nil(t, err)
// Save block1
err = cdb.saveBlock(block1, false)
assert.Nil(t, err)
// Test block1 was saved
testBlockWasSaved(t, cdb, block1)
// Clean up
os.RemoveAll(database.DbDir)
}
func testBlockWasSaved(t *testing.T, cdb *Chaindb, block payload.Block) {
// Fetch last block from database
lastBlock, err := cdb.GetLastBlock()
assert.Nil(t, err)
// Get byte representation of last block from database
byts, err := lastBlock.Bytes()
assert.Nil(t, err)
// Get byte representation of block that we saved
blockBytes, err := block.Bytes()
assert.Nil(t, err)
// Should be equal
assert.True(t, bytes.Equal(byts, blockBytes))
}
func randomHeaders(t *testing.T) []*payload.BlockBase {
assert := assert.New(t)
hdrsMsg, err := payload.NewHeadersMessage()
assert.Nil(err)
for i := 0; i < 2000; i++ {
err = hdrsMsg.AddHeader(randomBlockBase(t))
assert.Nil(err)
}
return hdrsMsg.Headers
}
func randomBlockBase(t *testing.T) *payload.BlockBase {
base := &payload.BlockBase{
Version: r.Uint32(),
PrevHash: randUint256(t),
MerkleRoot: randUint256(t),
Timestamp: r.Uint32(),
Index: r.Uint32(),
ConsensusData: r.Uint64(),
NextConsensus: randUint160(t),
Witness: transaction.Witness{
InvocationScript: []byte{0, 1, 2, 34, 56},
VerificationScript: []byte{0, 12, 3, 45, 66},
},
Hash: randUint256(t),
}
return base
}
func randomTxs(t *testing.T) []transaction.Transactioner {
var txs []transaction.Transactioner
for i := 0; i < 10; i++ {
tx := transaction.NewContract(0)
tx.AddInput(transaction.NewInput(randUint256(t), uint16(r.Int())))
tx.AddOutput(transaction.NewOutput(randUint256(t), r.Int63(), randUint160(t)))
txs = append(txs, tx)
}
return txs
}
func saveRandomHeaders(t *testing.T) (database.Database, *Chaindb, []*payload.BlockBase) {
db, err := database.New("temp.test")
assert.Nil(t, err)
cdb := &Chaindb{db}
hdrs := randomHeaders(t)
err = cdb.saveHeaders(hdrs)
assert.Nil(t, err)
return db, cdb, hdrs
}
func randUint256(t *testing.T) util.Uint256 {
slice := make([]byte, 32)
_, err := r.Read(slice)
u, err := util.Uint256DecodeBytes(slice)
assert.Nil(t, err)
return u
}
func randUint160(t *testing.T) util.Uint160 {
slice := make([]byte, 20)
_, err := r.Read(slice)
u, err := util.Uint160DecodeBytes(slice)
assert.Nil(t, err)
return u
}
// twoBlocksLinked will return two blocks, the second block spends from the utxos in the first
func twoBlocksLinked(t *testing.T) (payload.Block, payload.Block) {
genesisBase := randomBlockBase(t)
genesisTxs := randomTxs(t)
genesisBlock := payload.Block{BlockBase: *genesisBase, Txs: genesisTxs}
var txs []transaction.Transactioner
// Form transactions that spend from the genesis block
for _, tx := range genesisTxs {
txHash, err := tx.ID()
assert.Nil(t, err)
newTx := transaction.NewContract(0)
newTx.AddInput(transaction.NewInput(txHash, 0))
newTx.AddOutput(transaction.NewOutput(randUint256(t), r.Int63(), randUint160(t)))
txs = append(txs, newTx)
}
nextBase := randomBlockBase(t)
nextBlock := payload.Block{BlockBase: *nextBase, Txs: txs}
return genesisBlock, nextBlock
}

19
pkg/chain/errors.go Normal file
View file

@ -0,0 +1,19 @@
package chain
// ValidationError occurs when verificatio of the object fails
type ValidationError struct {
msg string
}
func (v ValidationError) Error() string {
return v.msg
}
// DatabaseError occurs when the chain fails to save the object in the database
type DatabaseError struct {
msg string
}
func (d DatabaseError) Error() string {
return d.msg
}

44
pkg/chaincfg/chaincfg.go Normal file
View file

@ -0,0 +1,44 @@
package chaincfg
import (
"bytes"
"encoding/hex"
"github.com/CityOfZion/neo-go/pkg/wire/payload"
"github.com/CityOfZion/neo-go/pkg/wire/protocol"
)
// Params are the parameters needed to setup the network
type Params struct {
GenesisBlock payload.Block
}
//NetParams returns the parameters for the chosen network magic
func NetParams(magic protocol.Magic) (Params, error) {
switch magic {
case protocol.MainNet:
return mainnet()
default:
return mainnet()
}
}
//Mainnet returns the parameters needed for mainnet
func mainnet() (Params, error) {
rawHex := "000000000000000000000000000000000000000000000000000000000000000000000000f41bc036e39b0d6b0579c851c6fde83af802fa4e57bec0bc3365eae3abf43f8065fc8857000000001dac2b7c0000000059e75d652b5d3827bf04c165bbe9ef95cca4bf55010001510400001dac2b7c00000000400000455b7b226c616e67223a227a682d434e222c226e616d65223a22e5b08fe89a81e882a1227d2c7b226c616e67223a22656e222c226e616d65223a22416e745368617265227d5d0000c16ff28623000000da1745e9b549bd0bfa1a569971c77eba30cd5a4b00000000400001445b7b226c616e67223a227a682d434e222c226e616d65223a22e5b08fe89a81e5b881227d2c7b226c616e67223a22656e222c226e616d65223a22416e74436f696e227d5d0000c16ff286230008009f7fd096d37ed2c0e3f7f0cfc924beef4ffceb680000000001000000019b7cffdaa674beae0f930ebe6085af9093e5fe56b34a5c220ccdcf6efc336fc50000c16ff28623005fa99d93303775fe50ca119c327759313eccfa1c01000151"
rawBytes, err := hex.DecodeString(rawHex)
if err != nil {
return Params{}, err
}
reader := bytes.NewReader(rawBytes)
block := payload.Block{}
err = block.Decode(reader)
if err != nil {
return Params{}, err
}
return Params{
GenesisBlock: block,
}, nil
}

View file

@ -0,0 +1,13 @@
package chaincfg
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestMainnet(t *testing.T) {
p, err := mainnet()
assert.Nil(t, err)
assert.Equal(t, p.GenesisBlock.Hash.ReverseString(), "d42561e3d30e15be6400b6df2f328e02d2bf6354c41dce433bc57687c82144bf")
}

25
pkg/connmgr/config.go Executable file
View file

@ -0,0 +1,25 @@
package connmgr
import (
"net"
)
// Config contains all methods which will be set by the caller to setup the connection manager.
type Config struct {
// GetAddress will return a single address for the connection manager to connect to
// This will be the source of addresses for the connection manager
GetAddress func() (string, error)
// OnConnection is called by the connection manager when we successfully connect to a peer
// The caller should ideally inform the address manager that we have connected to this address in this function
OnConnection func(conn net.Conn, addr string)
// OnAccept will take an established connection
OnAccept func(net.Conn)
// AddressPort is the address port of the local node in the format "address:port"
AddressPort string
// DialTimeout is the amount of time to wait, before we can disconnect a pending dialed connection
DialTimeout int
}

246
pkg/connmgr/connmgr.go Executable file
View file

@ -0,0 +1,246 @@
package connmgr
import (
"errors"
"fmt"
"net"
"net/http"
"time"
)
var (
// maxOutboundConn is the maximum number of active peers
// that the connection manager will try to have
maxOutboundConn = 10
// maxRetries is the maximum amount of successive retries that
// we can have before we stop dialing that peer
maxRetries = uint8(5)
)
// Connmgr manages pending/active/failed cnnections
type Connmgr struct {
config Config
PendingList map[string]*Request
ConnectedList map[string]*Request
actionch chan func()
}
//New creates a new connection manager
func New(cfg Config) *Connmgr {
cnnmgr := &Connmgr{
cfg,
make(map[string]*Request),
make(map[string]*Request),
make(chan func(), 300),
}
go func() {
listener, err := net.Listen("tcp", cfg.AddressPort)
if err != nil {
fmt.Println("Error connecting to outbound ", err)
}
defer func() {
listener.Close()
}()
for {
conn, err := listener.Accept()
if err != nil {
continue
}
go cfg.OnAccept(conn)
}
}()
return cnnmgr
}
// NewRequest will make a new connection gets the address from address func in config
// Then dials it and assigns it to pending
func (c *Connmgr) NewRequest() error {
// Fetch address
addr, err := c.config.GetAddress()
if err != nil {
return fmt.Errorf("error getting address " + err.Error())
}
r := &Request{
Addr: addr,
}
return c.Connect(r)
}
// Connect will dial the address in the Request
// Updating the request object depending on the outcome
func (c *Connmgr) Connect(r *Request) error {
r.Retries++
conn, err := c.dial(r.Addr)
if err != nil {
c.failed(r)
return err
}
r.Conn = conn
r.Inbound = true
// r.Permanent is set by the address manager/caller. default is false
// The permanent connections will be the ones that are hardcoded, e.g seed3.ngd.network
// or are reliable. The connmgr will be more leniennt to permanent addresses as they have
// a track record or reputation of being reliable.
return c.connected(r)
}
//Disconnect will remove the request from the connected/pending list and close the connection
func (c *Connmgr) Disconnect(addr string) {
var r *Request
// fetch from connected list
r, ok := c.ConnectedList[addr]
if !ok {
// If not in connected, check pending
r, _ = c.PendingList[addr]
}
c.disconnected(r)
}
// Dial is used to dial up connections given the addres and ip in the form address:port
func (c *Connmgr) dial(addr string) (net.Conn, error) {
dialTimeout := 1 * time.Second
conn, err := net.DialTimeout("tcp", addr, dialTimeout)
if err != nil {
if !isConnected() {
return nil, errors.New("Fatal Error: You do not seem to be connected to the internet")
}
return conn, err
}
return conn, nil
}
func (c *Connmgr) failed(r *Request) {
c.actionch <- func() {
// priority to check if it is permanent or inbound
// if so then these peers are valuable in NEO and so we will just retry another time
if r.Inbound || r.Permanent {
multiplier := time.Duration(r.Retries * 10)
time.AfterFunc(multiplier*time.Second,
func() {
c.Connect(r)
},
)
// if not then we should check if this request has had maxRetries
// if it has then get a new address
// if not then call Connect on it again
} else if r.Retries > maxRetries {
if c.config.GetAddress != nil {
go c.NewRequest()
}
} else {
go c.Connect(r)
}
}
}
// Disconnected is called when a peer disconnects.
// we take the addr from peer, which is also it's key in the map
// and we use it to remove it from the connectedList
func (c *Connmgr) disconnected(r *Request) error {
if r == nil {
// if object is nil, we return nil
return nil
}
// if for some reason the underlying connection is not closed, close it
err := r.Conn.Close()
if err != nil {
return err
}
// remove from any pending/connected list
delete(c.PendingList, r.Addr)
delete(c.ConnectedList, r.Addr)
// If permanent,then lets retry
if r.Permanent {
return c.Connect(r)
}
return nil
}
//Connected is called when the connection manager makes a successful connection.
func (c *Connmgr) connected(r *Request) error {
// This should not be the case, since we connected
if r == nil {
return errors.New("request object as nil inside of the connected function")
}
// reset retries to 0
r.Retries = 0
// add to connectedList
c.ConnectedList[r.Addr] = r
// remove from pending if it was there
delete(c.PendingList, r.Addr)
if c.config.OnConnection != nil {
c.config.OnConnection(r.Conn, r.Addr)
}
return nil
}
// Pending is synchronous, we do not want to continue with logic
// until we are certain it has been added to the pendingList
func (c *Connmgr) pending(r *Request) error {
if r == nil {
return errors.New("request object is nil")
}
c.PendingList[r.Addr] = r
return nil
}
// Run will start the connection manager
func (c *Connmgr) Run() error {
fmt.Println("Connection manager started")
go c.loop()
return nil
}
func (c *Connmgr) loop() {
for {
select {
case f := <-c.actionch:
f()
}
}
}
// https://stackoverflow.com/questions/50056144/check-for-internet-connection-from-application
func isConnected() (ok bool) {
_, err := http.Get("http://clients3.google.com/generate_204")
if err != nil {
return false
}
return true
}

107
pkg/connmgr/connmgr_test.go Executable file
View file

@ -0,0 +1,107 @@
package connmgr
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestDial(t *testing.T) {
cfg := Config{
GetAddress: nil,
OnConnection: nil,
OnAccept: nil,
AddressPort: "",
DialTimeout: 0,
}
cm := New(cfg)
err := cm.Run()
assert.Equal(t, nil, err)
ipport := "google.com:80" // google unlikely to go offline, a better approach to test Dialing is welcome.
conn, err := cm.dial(ipport)
assert.Equal(t, nil, err)
assert.NotEqual(t, nil, conn)
}
func TestConnect(t *testing.T) {
cfg := Config{
GetAddress: nil,
OnConnection: nil,
OnAccept: nil,
AddressPort: "",
DialTimeout: 0,
}
cm := New(cfg)
cm.Run()
ipport := "google.com:80"
r := Request{Addr: ipport}
err := cm.Connect(&r)
assert.Nil(t, err)
assert.Equal(t, 1, len(cm.ConnectedList))
}
func TestNewRequest(t *testing.T) {
address := "google.com:80"
var getAddr = func() (string, error) {
return address, nil
}
cfg := Config{
GetAddress: getAddr,
OnConnection: nil,
OnAccept: nil,
AddressPort: "",
DialTimeout: 0,
}
cm := New(cfg)
cm.Run()
cm.NewRequest()
if _, ok := cm.ConnectedList[address]; ok {
assert.Equal(t, true, ok)
assert.Equal(t, 1, len(cm.ConnectedList))
return
}
assert.Fail(t, "Could not find the address in the connected lists")
}
func TestDisconnect(t *testing.T) {
address := "google.com:80"
var getAddr = func() (string, error) {
return address, nil
}
cfg := Config{
GetAddress: getAddr,
OnConnection: nil,
OnAccept: nil,
AddressPort: "",
DialTimeout: 0,
}
cm := New(cfg)
cm.Run()
cm.NewRequest()
cm.Disconnect(address)
assert.Equal(t, 0, len(cm.ConnectedList))
}

22
pkg/connmgr/readme.md Executable file
View file

@ -0,0 +1,22 @@
# Package - Connection Manager
## Responsibility
- Manages the active, failed and pending connections for the node.
## Features
- Takes an Request, dials it and logs information based on the connectivity.
- Retry failed connections.
- Removable address source. The connection manager does not manage addresses, only connections.
## Usage
The following methods are exposed from the Connection manager:
- Connect(r *Request) : This takes a Request object and connects to it. It follow the same logic as NewRequest() however instead of getting the address from the datasource given upon initialisation, you directly feed the address you want to connect to.
- Disconnect(addrport string) : Given an address:port, this will disconnect it, close the connection and remove it from the connected and pending list, if it was there.

15
pkg/connmgr/request.go Executable file
View file

@ -0,0 +1,15 @@
package connmgr
import (
"net"
)
// Request is a layer on top of connection and allows us to add metadata to the net.Conn
// that the connection manager can use to determine whether to retry and other useful heuristics
type Request struct {
Conn net.Conn
Addr string
Permanent bool
Inbound bool
Retries uint8 // should not be trying more than 255 tries
}

View file

@ -1,8 +1,11 @@
package base58 package base58
import ( import (
"bytes"
"fmt" "fmt"
"math/big" "math/big"
"github.com/CityOfZion/neo-go/pkg/crypto/hash"
) )
const prefix rune = '1' const prefix rune = '1'
@ -76,3 +79,48 @@ func Encode(bytes []byte) string {
return encoded return encoded
} }
// CheckDecode decodes the given string.
func CheckDecode(s string) (b []byte, err error) {
b, err = Decode(s)
if err != nil {
return nil, err
}
for i := 0; i < len(s); i++ {
if s[i] != '1' {
break
}
b = append([]byte{0x00}, b...)
}
if len(b) < 5 {
return nil, fmt.Errorf("Invalid base-58 check string: missing checksum")
}
hash, err := hash.DoubleSha256(b[:len(b)-4])
if err != nil {
return nil, fmt.Errorf("Could not double sha256 data")
}
if bytes.Compare(hash[0:4], b[len(b)-4:]) != 0 {
return nil, fmt.Errorf("Invalid base-58 check string: invalid checksum")
}
// Strip the 4 byte long hash.
b = b[:len(b)-4]
return b, nil
}
// CheckEncode encodes b into a base-58 check encoded string.
func CheckEncode(b []byte) (string, error) {
hash, err := hash.DoubleSha256(b)
if err != nil {
return "", fmt.Errorf("Could not double sha256 data")
}
b = append(b, hash[0:4]...)
return Encode(b), nil
}

View file

@ -1,23 +1,24 @@
package database package database
import ( import (
"bytes"
"encoding/binary"
"fmt"
"github.com/CityOfZion/neo-go/pkg/wire/payload"
"github.com/CityOfZion/neo-go/pkg/wire/payload/transaction"
"github.com/CityOfZion/neo-go/pkg/wire/util"
"github.com/syndtr/goleveldb/leveldb" "github.com/syndtr/goleveldb/leveldb"
"github.com/syndtr/goleveldb/leveldb/errors" "github.com/syndtr/goleveldb/leveldb/errors"
ldbutil "github.com/syndtr/goleveldb/leveldb/util"
) )
//DbDir is the folder which all database files will be put under
// Structure /DbDir/net
const DbDir = "db/"
// LDB represents a leveldb object // LDB represents a leveldb object
type LDB struct { type LDB struct {
db *leveldb.DB db *leveldb.DB
path string Path string
} }
// ErrNotFound means that the value was not found in the db
var ErrNotFound = errors.New("value not found for that key")
// Database contains all methods needed for an object to be a database // Database contains all methods needed for an object to be a database
type Database interface { type Database interface {
// Has checks whether the key is in the database // Has checks whether the key is in the database
@ -28,37 +29,30 @@ type Database interface {
Get(key []byte) ([]byte, error) Get(key []byte) ([]byte, error)
// Delete deletes the given value for the key from the database // Delete deletes the given value for the key from the database
Delete(key []byte) error Delete(key []byte) error
//Prefix returns all values that start with key
Prefix(key []byte) ([][]byte, error)
// Close closes the underlying db object // Close closes the underlying db object
Close() error Close() error
} }
var (
// TX is the prefix used when inserting a tx into the db
TX = []byte("TX")
// HEADER is the prefix used when inserting a header into the db
HEADER = []byte("HEADER")
// LATESTHEADER is the prefix used when inserting the latests header into the db
LATESTHEADER = []byte("LH")
// UTXO is the prefix used when inserting a utxo into the db
UTXO = []byte("UTXO")
)
// New will return a new leveldb instance // New will return a new leveldb instance
func New(path string) *LDB { func New(path string) (*LDB, error) {
db, err := leveldb.OpenFile(path, nil) dbPath := DbDir + path
db, err := leveldb.OpenFile(dbPath, nil)
if err != nil {
return nil, err
}
if _, corrupted := err.(*errors.ErrCorrupted); corrupted { if _, corrupted := err.(*errors.ErrCorrupted); corrupted {
db, err = leveldb.RecoverFile(path, nil) db, err = leveldb.RecoverFile(path, nil)
} if err != nil {
return nil, err
if err != nil { }
return nil
} }
return &LDB{ return &LDB{
db, db,
path, dbPath,
} }, nil
} }
// Has implements the database interface // Has implements the database interface
@ -73,7 +67,15 @@ func (l *LDB) Put(key []byte, value []byte) error {
// Get implements the database interface // Get implements the database interface
func (l *LDB) Get(key []byte) ([]byte, error) { func (l *LDB) Get(key []byte) ([]byte, error) {
return l.db.Get(key, nil) val, err := l.db.Get(key, nil)
if err == nil {
return val, nil
}
if err == leveldb.ErrNotFound {
return val, ErrNotFound
}
return val, err
} }
// Delete implements the database interface // Delete implements the database interface
@ -86,79 +88,27 @@ func (l *LDB) Close() error {
return l.db.Close() return l.db.Close()
} }
// AddHeader adds a header into the database // Prefix implements the database interface
func (l *LDB) AddHeader(header *payload.BlockBase) error { func (l *LDB) Prefix(key []byte) ([][]byte, error) {
table := NewTable(l, HEADER) var results [][]byte
iter := l.db.NewIterator(ldbutil.BytesPrefix(key), nil)
for iter.Next() {
value := iter.Value()
// Copy the data, as we cannot modify it
// Once the iter has been released
deref := make([]byte, len(value))
copy(deref, value)
// Append result
results = append(results, deref)
byt, err := header.Bytes()
if err != nil {
fmt.Println("Could not Get bytes from decoded BlockBase")
return nil
} }
iter.Release()
fmt.Println("Adding Header, This should be batched!!!!") err := iter.Error()
return results, err
// This is the main mapping
//Key: HEADER+BLOCKHASH Value: contents of blockhash
key := header.Hash.Bytes()
err = table.Put(key, byt)
if err != nil {
fmt.Println("Error trying to add the original mapping into the DB for Header. Mapping is [Header]+[Hash]")
return err
}
// This is the secondary mapping
// Key: HEADER + BLOCKHEIGHT Value: blockhash
bh := uint32ToBytes(header.Index)
key = []byte(bh)
err = table.Put(key, header.Hash.Bytes())
if err != nil {
return err
}
// This is the third mapping
// WARNING: This assumes that headers are adding in order.
return table.Put(LATESTHEADER, header.Hash.Bytes())
}
// AddTransactions adds a set of transactions into the database
func (l *LDB) AddTransactions(blockhash util.Uint256, txs []transaction.Transactioner) error {
// SHOULD BE DONE IN BATCH!!!!
for i, tx := range txs {
buf := new(bytes.Buffer)
fmt.Println(tx.ID())
tx.Encode(buf)
txByt := buf.Bytes()
txhash, err := tx.ID()
if err != nil {
fmt.Println("Error adding transaction with bytes", txByt)
return err
}
// This is the original mapping
// Key: [TX] + TXHASH
key := append(TX, txhash.Bytes()...)
l.Put(key, txByt)
// This is the index
// Key: [TX] + BLOCKHASH + I <- i is the incrementer from the for loop
//Value : TXHASH
key = append(TX, blockhash.Bytes()...)
key = append(key, uint32ToBytes(uint32(i))...)
err = l.Put(key, txhash.Bytes())
if err != nil {
fmt.Println("Error could not add tx index into db")
return err
}
}
return nil
}
// BigEndian
func uint32ToBytes(h uint32) []byte {
a := make([]byte, 4)
binary.BigEndian.PutUint32(a, h)
return a
} }

View file

@ -6,27 +6,31 @@ import (
"github.com/CityOfZion/neo-go/pkg/database" "github.com/CityOfZion/neo-go/pkg/database"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/syndtr/goleveldb/leveldb/errors"
) )
const path = "temp" const path = "temp"
func cleanup(db *database.LDB) { func cleanup(db *database.LDB) {
db.Close() db.Close()
os.RemoveAll(path) os.RemoveAll(database.DbDir)
} }
func TestDBCreate(t *testing.T) { func TestDBCreate(t *testing.T) {
db := database.New(path)
db, err := database.New(path)
assert.Nil(t, err)
assert.NotEqual(t, nil, db) assert.NotEqual(t, nil, db)
cleanup(db) cleanup(db)
} }
func TestPutGet(t *testing.T) { func TestPutGet(t *testing.T) {
db := database.New(path)
db, err := database.New(path)
assert.Nil(t, err)
key := []byte("Hello") key := []byte("Hello")
value := []byte("World") value := []byte("World")
err := db.Put(key, value) err = db.Put(key, value)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
res, err := db.Get(key) res, err := db.Get(key)
@ -36,25 +40,28 @@ func TestPutGet(t *testing.T) {
} }
func TestPutDelete(t *testing.T) { func TestPutDelete(t *testing.T) {
db := database.New(path) db, err := database.New(path)
assert.Nil(t, err)
key := []byte("Hello") key := []byte("Hello")
value := []byte("World") value := []byte("World")
err := db.Put(key, value) err = db.Put(key, value)
err = db.Delete(key) err = db.Delete(key)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
res, err := db.Get(key) res, err := db.Get(key)
assert.Equal(t, errors.ErrNotFound, err) assert.Equal(t, database.ErrNotFound, err)
assert.Equal(t, res, []byte{}) assert.Equal(t, res, []byte{})
cleanup(db) cleanup(db)
} }
func TestHas(t *testing.T) { func TestHas(t *testing.T) {
db := database.New("temp")
db, err := database.New(path)
assert.Nil(t, err)
res, err := db.Has([]byte("NotExist")) res, err := db.Has([]byte("NotExist"))
assert.Equal(t, res, false) assert.Equal(t, res, false)
@ -73,8 +80,12 @@ func TestHas(t *testing.T) {
} }
func TestDBClose(t *testing.T) { func TestDBClose(t *testing.T) {
db := database.New("temp")
err := db.Close() db, err := database.New(path)
assert.Nil(t, err)
err = db.Close()
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
cleanup(db) cleanup(db)
} }

View file

@ -16,29 +16,35 @@ func NewTable(db Database, prefix []byte) *Table {
// Has implements the database interface // Has implements the database interface
func (t *Table) Has(key []byte) (bool, error) { func (t *Table) Has(key []byte) (bool, error) {
key = append(t.prefix, key...) prefixedKey := append(t.prefix, key...)
return t.db.Has(key) return t.db.Has(prefixedKey)
} }
// Put implements the database interface // Put implements the database interface
func (t *Table) Put(key []byte, value []byte) error { func (t *Table) Put(key []byte, value []byte) error {
key = append(t.prefix, key...) prefixedKey := append(t.prefix, key...)
return t.db.Put(key, value) return t.db.Put(prefixedKey, value)
} }
// Get implements the database interface // Get implements the database interface
func (t *Table) Get(key []byte) ([]byte, error) { func (t *Table) Get(key []byte) ([]byte, error) {
key = append(t.prefix, key...) prefixedKey := append(t.prefix, key...)
return t.db.Get(key) return t.db.Get(prefixedKey)
} }
// Delete implements the database interface // Delete implements the database interface
func (t *Table) Delete(key []byte) error { func (t *Table) Delete(key []byte) error {
key = append(t.prefix, key...) prefixedKey := append(t.prefix, key...)
return t.db.Delete(key) return t.db.Delete(prefixedKey)
} }
// Close implements the database interface // Close implements the database interface
func (t *Table) Close() error { func (t *Table) Close() error {
return nil return nil
} }
// Prefix implements the database interface
func (t *Table) Prefix(key []byte) ([][]byte, error) {
prefixedKey := append(t.prefix, key...)
return t.db.Prefix(prefixedKey)
}

View file

@ -14,34 +14,18 @@ type LocalConfig struct {
ProtocolVer protocol.Version ProtocolVer protocol.Version
Relay bool Relay bool
Port uint16 Port uint16
// pointer to config will keep the startheight updated for each version
//Message we plan to send // pointer to config will keep the startheight updated
StartHeight func() uint32 StartHeight func() uint32
// Response Handlers
OnHeader func(*Peer, *payload.HeadersMessage) OnHeader func(*Peer, *payload.HeadersMessage)
OnGetHeaders func(msg *payload.GetHeadersMessage) // returns HeaderMessage OnGetHeaders func(*Peer, *payload.GetHeadersMessage)
OnAddr func(*Peer, *payload.AddrMessage) OnAddr func(*Peer, *payload.AddrMessage)
OnGetAddr func(*Peer, *payload.GetAddrMessage) OnGetAddr func(*Peer, *payload.GetAddrMessage)
OnInv func(*Peer, *payload.InvMessage) OnInv func(*Peer, *payload.InvMessage)
OnGetData func(msg *payload.GetDataMessage) OnGetData func(*Peer, *payload.GetDataMessage)
OnBlock func(*Peer, *payload.BlockMessage) OnBlock func(*Peer, *payload.BlockMessage)
OnGetBlocks func(msg *payload.GetBlocksMessage) OnGetBlocks func(*Peer, *payload.GetBlocksMessage)
OnTx func(*Peer, *payload.TXMessage)
} }
// func DefaultConfig() LocalConfig {
// return LocalConfig{
// Net: protocol.MainNet,
// UserAgent: "NEO-GO-Default",
// Services: protocol.NodePeerService,
// Nonce: 1200,
// ProtocolVer: 0,
// Relay: false,
// Port: 10332,
// // pointer to config will keep the startheight updated for each version
// //Message we plan to send
// StartHeight: DefaultHeight,
// }
// }
// func DefaultHeight() uint32 {
// return 10
// }

View file

@ -58,6 +58,8 @@ type Peer struct {
config LocalConfig config LocalConfig
conn net.Conn conn net.Conn
startHeight uint32
// atomic vals // atomic vals
disconnected int32 disconnected int32
@ -84,20 +86,18 @@ type Peer struct {
// NewPeer returns a new NEO peer // NewPeer returns a new NEO peer
func NewPeer(con net.Conn, inbound bool, cfg LocalConfig) *Peer { func NewPeer(con net.Conn, inbound bool, cfg LocalConfig) *Peer {
p := Peer{} return &Peer{
p.inch = make(chan func(), inputBufferSize) inch: make(chan func(), inputBufferSize),
p.outch = make(chan func(), outputBufferSize) outch: make(chan func(), outputBufferSize),
p.quitch = make(chan struct{}, 1) quitch: make(chan struct{}, 1),
p.inbound = inbound inbound: inbound,
p.config = cfg config: cfg,
p.conn = con conn: con,
p.createdAt = time.Now() createdAt: time.Now(),
p.addr = p.conn.RemoteAddr().String() startHeight: 0,
addr: con.RemoteAddr().String(),
p.Detector = stall.NewDetector(responseTime, tickerInterval) Detector: stall.NewDetector(responseTime, tickerInterval),
}
// TODO: set the unchangeable states
return &p
} }
// Write to a peer // Write to a peer
@ -125,7 +125,6 @@ func (p *Peer) Disconnect() {
p.conn.Close() p.conn.Close()
fmt.Println("Disconnected Peer with address", p.RemoteAddr().String()) fmt.Println("Disconnected Peer with address", p.RemoteAddr().String())
} }
// Port returns the peers port // Port returns the peers port
@ -138,6 +137,11 @@ func (p *Peer) CreatedAt() time.Time {
return p.createdAt return p.createdAt
} }
// Height returns the latest recorded height of this peer
func (p *Peer) Height() uint32 {
return p.startHeight
}
// CanRelay returns true, if the peer can relay information // CanRelay returns true, if the peer can relay information
func (p *Peer) CanRelay() bool { func (p *Peer) CanRelay() bool {
return p.relay return p.relay
@ -163,11 +167,6 @@ func (p *Peer) Inbound() bool {
return p.inbound return p.inbound
} }
// UserAgent returns this nodes, useragent
func (p *Peer) UserAgent() string {
return p.config.UserAgent
}
// IsVerackReceived returns true, if this node has // IsVerackReceived returns true, if this node has
// received a verack from this peer // received a verack from this peer
func (p *Peer) IsVerackReceived() bool { func (p *Peer) IsVerackReceived() bool {
@ -176,11 +175,9 @@ func (p *Peer) IsVerackReceived() bool {
//NotifyDisconnect returns once the peer has disconnected //NotifyDisconnect returns once the peer has disconnected
// Blocking // Blocking
func (p *Peer) NotifyDisconnect() bool { func (p *Peer) NotifyDisconnect() {
fmt.Println("Peer has not disconnected yet")
<-p.quitch <-p.quitch
fmt.Println("Peer has just disconnected") fmt.Println("Peer has just disconnected")
return true
} }
//End of Exposed API functions// //End of Exposed API functions//
@ -195,14 +192,15 @@ func (p *Peer) PingLoop() { /*not implemented in other neo clients*/ }
func (p *Peer) Run() error { func (p *Peer) Run() error {
err := p.Handshake() err := p.Handshake()
if err != nil {
return err
}
go p.StartProtocol() go p.StartProtocol()
go p.ReadLoop() go p.ReadLoop()
go p.WriteLoop() go p.WriteLoop()
//go p.PingLoop() // since it is not implemented. It will disconnect all other impls. //go p.PingLoop() // since it is not implemented. It will disconnect all other impls.
return err return nil
} }
// StartProtocol run as a go-routine, will act as our queue for messages // StartProtocol run as a go-routine, will act as our queue for messages
@ -303,128 +301,17 @@ func (p *Peer) WriteLoop() {
} }
} }
// OnGetData is called when a GetData message is received // Outgoing Requests
func (p *Peer) OnGetData(msg *payload.GetDataMessage) {
p.inch <- func() {
if p.config.OnInv != nil {
p.config.OnGetData(msg)
}
fmt.Println("That was an getdata Message please pass func down through config", msg.Command())
}
}
//OnTX is callwed when a TX message is received
func (p *Peer) OnTX(msg *payload.TXMessage) {
p.inch <- func() {
getdata, err := payload.NewGetDataMessage(payload.InvTypeTx)
if err != nil {
fmt.Println("Eor", err)
}
id, err := msg.Tx.ID()
getdata.AddHash(id)
p.Write(getdata)
}
}
// OnInv is called when a Inv message is received
func (p *Peer) OnInv(msg *payload.InvMessage) {
p.inch <- func() {
if p.config.OnInv != nil {
p.config.OnInv(p, msg)
}
fmt.Println("That was an inv Message please pass func down through config", msg.Command())
}
}
// OnGetHeaders is called when a GetHeaders message is received
func (p *Peer) OnGetHeaders(msg *payload.GetHeadersMessage) {
p.inch <- func() {
if p.config.OnGetHeaders != nil {
p.config.OnGetHeaders(msg)
}
fmt.Println("That was a getheaders message, please pass func down through config", msg.Command())
}
}
// OnAddr is called when a Addr message is received
func (p *Peer) OnAddr(msg *payload.AddrMessage) {
p.inch <- func() {
if p.config.OnAddr != nil {
p.config.OnAddr(p, msg)
}
fmt.Println("That was a addr message, please pass func down through config", msg.Command())
}
}
// OnGetAddr is called when a GetAddr message is received
func (p *Peer) OnGetAddr(msg *payload.GetAddrMessage) {
p.inch <- func() {
if p.config.OnGetAddr != nil {
p.config.OnGetAddr(p, msg)
}
fmt.Println("That was a getaddr message, please pass func down through config", msg.Command())
}
}
// OnGetBlocks is called when a GetBlocks message is received
func (p *Peer) OnGetBlocks(msg *payload.GetBlocksMessage) {
p.inch <- func() {
if p.config.OnGetBlocks != nil {
p.config.OnGetBlocks(msg)
}
fmt.Println("That was a getblocks message, please pass func down through config", msg.Command())
}
}
// OnBlocks is called when a Blocks message is received
func (p *Peer) OnBlocks(msg *payload.BlockMessage) {
p.inch <- func() {
if p.config.OnBlock != nil {
p.config.OnBlock(p, msg)
}
}
}
// OnVersion Listener will be called
// during the handshake, any error checking should be done here for the versionMessage.
// This should only ever be called during the handshake. Any other place and the peer will disconnect.
func (p *Peer) OnVersion(msg *payload.VersionMessage) error {
if msg.Nonce == p.config.Nonce {
p.conn.Close()
return errors.New("Self connection, disconnecting Peer")
}
p.versionKnown = true
p.port = msg.Port
p.services = msg.Services
p.userAgent = string(msg.UserAgent)
p.createdAt = time.Now()
p.relay = msg.Relay
return nil
}
// OnHeaders is called when a Headers message is received
func (p *Peer) OnHeaders(msg *payload.HeadersMessage) {
fmt.Println("We have received the headers")
p.inch <- func() {
if p.config.OnHeader != nil {
p.config.OnHeader(p, msg)
}
}
}
// RequestHeaders will write a getheaders to this peer // RequestHeaders will write a getheaders to this peer
func (p *Peer) RequestHeaders(hash util.Uint256) error { func (p *Peer) RequestHeaders(hash util.Uint256) error {
c := make(chan error, 0) c := make(chan error, 0)
p.outch <- func() { p.outch <- func() {
p.Detector.AddMessage(command.GetHeaders)
getHeaders, err := payload.NewGetHeadersMessage([]util.Uint256{hash}, util.Uint256{}) getHeaders, err := payload.NewGetHeadersMessage([]util.Uint256{hash}, util.Uint256{})
err = p.Write(getHeaders) err = p.Write(getHeaders)
if err != nil {
p.Detector.AddMessage(command.GetHeaders)
}
c <- err c <- err
} }
return <-c return <-c
@ -435,17 +322,19 @@ func (p *Peer) RequestBlocks(hashes []util.Uint256) error {
c := make(chan error, 0) c := make(chan error, 0)
p.outch <- func() { p.outch <- func() {
p.Detector.AddMessage(command.GetData)
getdata, err := payload.NewGetDataMessage(payload.InvTypeBlock) getdata, err := payload.NewGetDataMessage(payload.InvTypeBlock)
err = getdata.AddHashes(hashes) err = getdata.AddHashes(hashes)
if err != nil { if err != nil {
c <- err c <- err
return return
} }
err = p.Write(getdata) err = p.Write(getdata)
if err != nil {
p.Detector.AddMessage(command.GetData)
}
c <- err c <- err
} }
return <-c return <-c
} }

View file

@ -1,7 +1,6 @@
package peer_test package peer_test
import ( import (
"fmt"
"net" "net"
"testing" "testing"
"time" "time"
@ -21,11 +20,11 @@ func returnConfig() peer.LocalConfig {
OnAddr := func(p *peer.Peer, msg *payload.AddrMessage) {} OnAddr := func(p *peer.Peer, msg *payload.AddrMessage) {}
OnHeader := func(p *peer.Peer, msg *payload.HeadersMessage) {} OnHeader := func(p *peer.Peer, msg *payload.HeadersMessage) {}
OnGetHeaders := func(msg *payload.GetHeadersMessage) {} OnGetHeaders := func(p *peer.Peer, msg *payload.GetHeadersMessage) {}
OnInv := func(p *peer.Peer, msg *payload.InvMessage) {} OnInv := func(p *peer.Peer, msg *payload.InvMessage) {}
OnGetData := func(msg *payload.GetDataMessage) {} OnGetData := func(p *peer.Peer, msg *payload.GetDataMessage) {}
OnBlock := func(p *peer.Peer, msg *payload.BlockMessage) {} OnBlock := func(p *peer.Peer, msg *payload.BlockMessage) {}
OnGetBlocks := func(msg *payload.GetBlocksMessage) {} OnGetBlocks := func(p *peer.Peer, msg *payload.GetBlocksMessage) {}
return peer.LocalConfig{ return peer.LocalConfig{
Net: protocol.MainNet, Net: protocol.MainNet,
@ -157,17 +156,9 @@ func TestConfigurations(t *testing.T) {
assert.Equal(t, config.Services, p.Services()) assert.Equal(t, config.Services, p.Services())
assert.Equal(t, config.UserAgent, p.UserAgent())
assert.Equal(t, config.Relay, p.CanRelay()) assert.Equal(t, config.Relay, p.CanRelay())
assert.WithinDuration(t, time.Now(), p.CreatedAt(), 1*time.Second) assert.WithinDuration(t, time.Now(), p.CreatedAt(), 1*time.Second)
}
func TestHandshakeCancelled(t *testing.T) {
// These are the conditions which should invalidate the handshake.
// Make sure peer is disconnected.
} }
func TestPeerDisconnect(t *testing.T) { func TestPeerDisconnect(t *testing.T) {
@ -178,21 +169,17 @@ func TestPeerDisconnect(t *testing.T) {
inbound := true inbound := true
config := returnConfig() config := returnConfig()
p := peer.NewPeer(conn, inbound, config) p := peer.NewPeer(conn, inbound, config)
fmt.Println("Calling disconnect")
p.Disconnect() p.Disconnect()
fmt.Println("Disconnect finished calling") verack, err := payload.NewVerackMessage()
verack, _ := payload.NewVerackMessage() assert.Nil(t, err)
fmt.Println(" We good here") err = p.Write(verack)
assert.NotNil(t, err)
err := p.Write(verack) // Check if stall detector is still running
assert.NotEqual(t, err, nil)
// Check if Stall detector is still running
_, ok := <-p.Detector.Quitch _, ok := <-p.Detector.Quitch
assert.Equal(t, ok, false) assert.Equal(t, ok, false)
} }
func TestNotifyDisconnect(t *testing.T) { func TestNotifyDisconnect(t *testing.T) {

View file

@ -0,0 +1,111 @@
package peer
import (
"errors"
"time"
"github.com/CityOfZion/neo-go/pkg/wire/payload"
)
// OnGetData is called when a GetData message is received
func (p *Peer) OnGetData(msg *payload.GetDataMessage) {
p.inch <- func() {
if p.config.OnInv != nil {
p.config.OnGetData(p, msg)
}
}
}
//OnTX is called when a TX message is received
func (p *Peer) OnTX(msg *payload.TXMessage) {
p.inch <- func() {
p.inch <- func() {
if p.config.OnTx != nil {
p.config.OnTx(p, msg)
}
}
}
}
// OnInv is called when a Inv message is received
func (p *Peer) OnInv(msg *payload.InvMessage) {
p.inch <- func() {
if p.config.OnInv != nil {
p.config.OnInv(p, msg)
}
}
}
// OnGetHeaders is called when a GetHeaders message is received
func (p *Peer) OnGetHeaders(msg *payload.GetHeadersMessage) {
p.inch <- func() {
if p.config.OnGetHeaders != nil {
p.config.OnGetHeaders(p, msg)
}
}
}
// OnAddr is called when a Addr message is received
func (p *Peer) OnAddr(msg *payload.AddrMessage) {
p.inch <- func() {
if p.config.OnAddr != nil {
p.config.OnAddr(p, msg)
}
}
}
// OnGetAddr is called when a GetAddr message is received
func (p *Peer) OnGetAddr(msg *payload.GetAddrMessage) {
p.inch <- func() {
if p.config.OnGetAddr != nil {
p.config.OnGetAddr(p, msg)
}
}
}
// OnGetBlocks is called when a GetBlocks message is received
func (p *Peer) OnGetBlocks(msg *payload.GetBlocksMessage) {
p.inch <- func() {
if p.config.OnGetBlocks != nil {
p.config.OnGetBlocks(p, msg)
}
}
}
// OnBlocks is called when a Blocks message is received
func (p *Peer) OnBlocks(msg *payload.BlockMessage) {
p.Detector.RemoveMessage(msg.Command())
p.inch <- func() {
if p.config.OnBlock != nil {
p.config.OnBlock(p, msg)
}
}
}
// OnHeaders is called when a Headers message is received
func (p *Peer) OnHeaders(msg *payload.HeadersMessage) {
p.Detector.RemoveMessage(msg.Command())
p.inch <- func() {
if p.config.OnHeader != nil {
p.config.OnHeader(p, msg)
}
}
}
// OnVersion Listener will be called
// during the handshake, any error checking should be done here for the versionMessage.
// This should only ever be called during the handshake. Any other place and the peer will disconnect.
func (p *Peer) OnVersion(msg *payload.VersionMessage) error {
if msg.Nonce == p.config.Nonce {
p.conn.Close()
return errors.New("self connection, disconnecting Peer")
}
p.versionKnown = true
p.port = msg.Port
p.services = msg.Services
p.userAgent = string(msg.UserAgent)
p.createdAt = time.Now()
p.relay = msg.Relay
p.startHeight = msg.StartHeight
return nil
}

View file

@ -61,6 +61,7 @@ func (d *Detector) loop() {
d.lock.RUnlock() d.lock.RUnlock()
for _, deadline := range resp { for _, deadline := range resp {
if now.After(deadline) { if now.After(deadline) {
fmt.Println(resp)
fmt.Println("Deadline passed") fmt.Println("Deadline passed")
return return
} }
@ -99,7 +100,7 @@ func (d *Detector) AddMessage(cmd command.Type) {
// peer. This will remove the pendingresponse message from the map. // peer. This will remove the pendingresponse message from the map.
// The command passed through is the command we received // The command passed through is the command we received
func (d *Detector) RemoveMessage(cmd command.Type) { func (d *Detector) RemoveMessage(cmd command.Type) {
cmds := d.addMessage(cmd) cmds := d.removeMessage(cmd)
d.lock.Lock() d.lock.Lock()
for _, cmd := range cmds { for _, cmd := range cmds {
delete(d.responses, cmd) delete(d.responses, cmd)
@ -137,10 +138,8 @@ func (d *Detector) addMessage(cmd command.Type) []command.Type {
case command.GetAddr: case command.GetAddr:
// We now will expect a Headers Message // We now will expect a Headers Message
cmds = append(cmds, command.Addr) cmds = append(cmds, command.Addr)
case command.GetData: case command.GetData:
// We will now expect a block/tx message // We will now expect a block/tx message
// We can optimise this by including the exact inventory type, however it is not needed
cmds = append(cmds, command.Block) cmds = append(cmds, command.Block)
cmds = append(cmds, command.TX) cmds = append(cmds, command.TX)
case command.GetBlocks: case command.GetBlocks:
@ -159,19 +158,18 @@ func (d *Detector) removeMessage(cmd command.Type) []command.Type {
switch cmd { switch cmd {
case command.Block: case command.Block:
// We will now expect a block/tx message // We will now remove a block and tx message
cmds = append(cmds, command.Block) cmds = append(cmds, command.Block)
cmds = append(cmds, command.TX) cmds = append(cmds, command.TX)
case command.TX: case command.TX:
// We will now expect a block/tx message // We will now remove a block and tx message
cmds = append(cmds, command.Block) cmds = append(cmds, command.Block)
cmds = append(cmds, command.TX) cmds = append(cmds, command.TX)
case command.GetBlocks: case command.Verack:
// we will now expect a inv message
cmds = append(cmds, command.Inv)
default:
// We will now expect a verack // We will now expect a verack
cmds = append(cmds, cmd) cmds = append(cmds, cmd)
default:
cmds = append(cmds, cmd)
} }
return cmds return cmds
} }

View file

@ -22,7 +22,7 @@ func TestAddRemoveMessage(t *testing.T) {
assert.Equal(t, 1, len(mp)) assert.Equal(t, 1, len(mp))
assert.IsType(t, time.Time{}, mp[command.GetAddr]) assert.IsType(t, time.Time{}, mp[command.GetAddr])
d.RemoveMessage(command.GetAddr) d.RemoveMessage(command.Addr)
mp = d.GetMessages() mp = d.GetMessages()
assert.Equal(t, 0, len(mp)) assert.Equal(t, 0, len(mp))

155
pkg/peermgr/blockcache.go Normal file
View file

@ -0,0 +1,155 @@
package peermgr
import (
"errors"
"sort"
"sync"
"github.com/CityOfZion/neo-go/pkg/wire/util"
)
var (
//ErrCacheLimit is returned when the cache limit is reached
ErrCacheLimit = errors.New("nomore items can be added to the cache")
//ErrNoItems is returned when pickItem is called and there are no items in the cache
ErrNoItems = errors.New("there are no items in the cache")
//ErrDuplicateItem is returned when you try to add the same item, more than once to the cache
ErrDuplicateItem = errors.New("this item is already in the cache")
)
//BlockInfo holds the necessary information that the cache needs
// to sort and store block requests
type BlockInfo struct {
BlockHash util.Uint256
BlockIndex uint32
}
// Equals returns true if two blockInfo objects
// have the same hash and the same index
func (bi *BlockInfo) Equals(other BlockInfo) bool {
return bi.BlockHash.Equals(other.BlockHash) && bi.BlockIndex == other.BlockIndex
}
// indexSorter sorts the blockInfos by blockIndex.
type indexSorter []BlockInfo
func (is indexSorter) Len() int { return len(is) }
func (is indexSorter) Swap(i, j int) { is[i], is[j] = is[j], is[i] }
func (is indexSorter) Less(i, j int) bool { return is[i].BlockIndex < is[j].BlockIndex }
//blockCache will cache any pending block requests
// for the node when there are no available nodes
type blockCache struct {
cacheLimit int
cacheLock sync.Mutex
cache []BlockInfo
}
func newBlockCache(cacheLimit int) *blockCache {
return &blockCache{
cache: make([]BlockInfo, 0, cacheLimit),
cacheLimit: cacheLimit,
}
}
func (bc *blockCache) addBlockInfo(bi BlockInfo) error {
if bc.cacheLen() == bc.cacheLimit {
return ErrCacheLimit
}
bc.cacheLock.Lock()
defer bc.cacheLock.Unlock()
// Check for duplicates. slice will always be small so a simple for loop will work
for _, bInfo := range bc.cache {
if bInfo.Equals(bi) {
return ErrDuplicateItem
}
}
bc.cache = append(bc.cache, bi)
sort.Sort(indexSorter(bc.cache))
return nil
}
func (bc *blockCache) addBlockInfos(bis []BlockInfo) error {
if len(bis)+bc.cacheLen() > bc.cacheLimit {
return errors.New("too many items to add, this will exceed the cache limit")
}
for _, bi := range bis {
err := bc.addBlockInfo(bi)
if err != nil {
return err
}
}
return nil
}
func (bc *blockCache) cacheLen() int {
bc.cacheLock.Lock()
defer bc.cacheLock.Unlock()
return len(bc.cache)
}
func (bc *blockCache) pickFirstItem() (BlockInfo, error) {
return bc.pickItem(0)
}
func (bc *blockCache) pickAllItems() ([]BlockInfo, error) {
numOfItems := bc.cacheLen()
items := make([]BlockInfo, 0, numOfItems)
for i := 0; i < numOfItems; i++ {
bi, err := bc.pickFirstItem()
if err != nil {
return nil, err
}
items = append(items, bi)
}
return items, nil
}
func (bc *blockCache) pickItem(i uint) (BlockInfo, error) {
if bc.cacheLen() < 1 {
return BlockInfo{}, ErrNoItems
}
if i >= uint(bc.cacheLen()) {
return BlockInfo{}, errors.New("index out of range")
}
bc.cacheLock.Lock()
defer bc.cacheLock.Unlock()
item := bc.cache[i]
bc.cache = append(bc.cache[:i], bc.cache[i+1:]...)
return item, nil
}
func (bc *blockCache) removeHash(hashToRemove util.Uint256) error {
index, err := bc.findHash(hashToRemove)
if err != nil {
return err
}
_, err = bc.pickItem(uint(index))
return err
}
func (bc *blockCache) findHash(hashToFind util.Uint256) (int, error) {
bc.cacheLock.Lock()
defer bc.cacheLock.Unlock()
for i, bInfo := range bc.cache {
if bInfo.BlockHash.Equals(hashToFind) {
return i, nil
}
}
return -1, errors.New("hash cannot be found in the cache")
}

View file

@ -0,0 +1,80 @@
package peermgr
import (
"math/rand"
"testing"
"github.com/CityOfZion/neo-go/pkg/wire/util"
"github.com/stretchr/testify/assert"
)
func TestAddBlock(t *testing.T) {
bc := &blockCache{
cacheLimit: 20,
}
bi := randomBlockInfo(t)
err := bc.addBlockInfo(bi)
assert.Equal(t, nil, err)
assert.Equal(t, 1, bc.cacheLen())
err = bc.addBlockInfo(bi)
assert.Equal(t, ErrDuplicateItem, err)
assert.Equal(t, 1, bc.cacheLen())
}
func TestCacheLimit(t *testing.T) {
bc := &blockCache{
cacheLimit: 20,
}
for i := 0; i < bc.cacheLimit; i++ {
err := bc.addBlockInfo(randomBlockInfo(t))
assert.Equal(t, nil, err)
}
err := bc.addBlockInfo(randomBlockInfo(t))
assert.Equal(t, ErrCacheLimit, err)
assert.Equal(t, bc.cacheLimit, bc.cacheLen())
}
func TestPickItem(t *testing.T) {
bc := &blockCache{
cacheLimit: 20,
}
for i := 0; i < bc.cacheLimit; i++ {
err := bc.addBlockInfo(randomBlockInfo(t))
assert.Equal(t, nil, err)
}
for i := 0; i < bc.cacheLimit; i++ {
_, err := bc.pickFirstItem()
assert.Equal(t, nil, err)
}
assert.Equal(t, 0, bc.cacheLen())
}
func randomUint256(t *testing.T) util.Uint256 {
rand32 := make([]byte, 32)
rand.Read(rand32)
u, err := util.Uint256DecodeBytes(rand32)
assert.Equal(t, nil, err)
return u
}
func randomBlockInfo(t *testing.T) BlockInfo {
return BlockInfo{
randomUint256(t),
rand.Uint32(),
}
}

227
pkg/peermgr/peermgr.go Normal file
View file

@ -0,0 +1,227 @@
package peermgr
import (
"errors"
"fmt"
"sync"
"github.com/CityOfZion/neo-go/pkg/wire/command"
"github.com/CityOfZion/neo-go/pkg/wire/util"
)
const (
// blockCacheLimit is the maximum amount of pending requests that the cache can hold
pendingBlockCacheLimit = 20
//peerBlockCacheLimit is the maximum amount of inflight blocks that a peer can
// have, before they are flagged as busy
peerBlockCacheLimit = 1
)
var (
//ErrNoAvailablePeers is returned when a request for data from a peer is invoked
// but there are no available peers to request data from
ErrNoAvailablePeers = errors.New("there are no available peers to interact with")
// ErrUnknownPeer is returned when a peer that the peer manager does not know about
// sends a message to this node
ErrUnknownPeer = errors.New("this peer has not been registered with the peer manager")
)
//mPeer represents a peer that is managed by the peer manager
type mPeer interface {
Disconnect()
RequestBlocks([]util.Uint256) error
RequestHeaders(util.Uint256) error
NotifyDisconnect()
}
type peerstats struct {
// when a peer is sent a blockRequest
// the peermanager will track this using this blockCache
blockCache *blockCache
// all other requests will be tracked using the requests map
requests map[command.Type]bool
}
//PeerMgr manages all peers that the node is connected to
type PeerMgr struct {
pLock sync.RWMutex
peers map[mPeer]peerstats
requestCache *blockCache
}
//New returns a new peermgr object
func New() *PeerMgr {
return &PeerMgr{
peers: make(map[mPeer]peerstats),
requestCache: newBlockCache(pendingBlockCacheLimit),
}
}
// AddPeer adds a peer to the list of managed peers
func (pmgr *PeerMgr) AddPeer(peer mPeer) {
pmgr.pLock.Lock()
defer pmgr.pLock.Unlock()
if _, exists := pmgr.peers[peer]; exists {
return
}
pmgr.peers[peer] = peerstats{
requests: make(map[command.Type]bool),
blockCache: newBlockCache(peerBlockCacheLimit),
}
go pmgr.onDisconnect(peer)
}
//MsgReceived notifies the peer manager that we have received a
// message from a peer
func (pmgr *PeerMgr) MsgReceived(peer mPeer, cmd command.Type) error {
pmgr.pLock.Lock()
defer pmgr.pLock.Unlock()
// if peer was unknown then disconnect
val, ok := pmgr.peers[peer]
if !ok {
go func() {
peer.NotifyDisconnect()
}()
peer.Disconnect()
return ErrUnknownPeer
}
val.requests[cmd] = false
return nil
}
//BlockMsgReceived notifies the peer manager that we have received a
// block message from a peer
func (pmgr *PeerMgr) BlockMsgReceived(peer mPeer, bi BlockInfo) error {
// if peer was unknown then disconnect
val, ok := pmgr.peers[peer]
if !ok {
go func() {
peer.NotifyDisconnect()
}()
peer.Disconnect()
return ErrUnknownPeer
}
// // remove item from the peersBlock cache
err := val.blockCache.removeHash(bi.BlockHash)
if err != nil {
return err
}
// check if cache empty, if so then return
if pmgr.requestCache.cacheLen() == 0 {
return nil
}
// Try to clean an item from the pendingBlockCache, a peer has just finished serving a block request
cachedBInfo, err := pmgr.requestCache.pickFirstItem()
if err != nil {
return err
}
return pmgr.blockCallPeer(cachedBInfo, func(p mPeer) error {
return p.RequestBlocks([]util.Uint256{cachedBInfo.BlockHash})
})
}
// Len returns the amount of peers that the peer manager
//currently knows about
func (pmgr *PeerMgr) Len() int {
pmgr.pLock.Lock()
defer pmgr.pLock.Unlock()
return len(pmgr.peers)
}
// RequestBlock will request a block from the most
// available peer. Then update it's stats, so we know that
// this peer is busy
func (pmgr *PeerMgr) RequestBlock(bi BlockInfo) error {
pmgr.pLock.Lock()
defer pmgr.pLock.Unlock()
err := pmgr.blockCallPeer(bi, func(p mPeer) error {
return p.RequestBlocks([]util.Uint256{bi.BlockHash})
})
if err == ErrNoAvailablePeers {
return pmgr.requestCache.addBlockInfo(bi)
}
return err
}
// RequestHeaders will request a headers from the most available peer.
func (pmgr *PeerMgr) RequestHeaders(hash util.Uint256) error {
pmgr.pLock.Lock()
defer pmgr.pLock.Unlock()
return pmgr.callPeerForCmd(command.Headers, func(p mPeer) error {
return p.RequestHeaders(hash)
})
}
func (pmgr *PeerMgr) callPeerForCmd(cmd command.Type, f func(p mPeer) error) error {
for peer, stats := range pmgr.peers {
if !stats.requests[cmd] {
stats.requests[cmd] = true
return f(peer)
}
}
return ErrNoAvailablePeers
}
func (pmgr *PeerMgr) blockCallPeer(bi BlockInfo, f func(p mPeer) error) error {
for peer, stats := range pmgr.peers {
if stats.blockCache.cacheLen() < peerBlockCacheLimit {
err := stats.blockCache.addBlockInfo(bi)
if err != nil {
return err
}
return f(peer)
}
}
return ErrNoAvailablePeers
}
func (pmgr *PeerMgr) onDisconnect(p mPeer) {
// Blocking until peer is disconnected
p.NotifyDisconnect()
pmgr.pLock.Lock()
defer func() {
delete(pmgr.peers, p)
pmgr.pLock.Unlock()
}()
// Add all of peers outstanding block requests into
// the peer managers pendingBlockRequestCache
val, ok := pmgr.peers[p]
if !ok {
return
}
pendingRequests, err := val.blockCache.pickAllItems()
if err != nil {
fmt.Println(err.Error())
return
}
err = pmgr.requestCache.addBlockInfos(pendingRequests)
if err != nil {
fmt.Println(err.Error())
return
}
}

201
pkg/peermgr/peermgr_test.go Normal file
View file

@ -0,0 +1,201 @@
package peermgr
import (
"testing"
"github.com/CityOfZion/neo-go/pkg/wire/command"
"github.com/CityOfZion/neo-go/pkg/wire/util"
"github.com/stretchr/testify/assert"
)
type peer struct {
quit chan bool
nonce int
disconnected bool
blockRequested int
headersRequested int
}
func (p *peer) Disconnect() {
p.disconnected = true
p.quit <- true
}
func (p *peer) RequestBlocks([]util.Uint256) error {
p.blockRequested++
return nil
}
func (p *peer) RequestHeaders(util.Uint256) error {
p.headersRequested++
return nil
}
func (p *peer) NotifyDisconnect() {
<-p.quit
}
func TestAddPeer(t *testing.T) {
pmgr := New()
peerA := &peer{nonce: 1}
peerB := &peer{nonce: 2}
peerC := &peer{nonce: 3}
pmgr.AddPeer(peerA)
pmgr.AddPeer(peerB)
pmgr.AddPeer(peerC)
pmgr.AddPeer(peerC)
assert.Equal(t, 3, pmgr.Len())
}
func TestRequestBlocks(t *testing.T) {
pmgr := New()
peerA := &peer{nonce: 1}
peerB := &peer{nonce: 2}
peerC := &peer{nonce: 3}
pmgr.AddPeer(peerA)
pmgr.AddPeer(peerB)
pmgr.AddPeer(peerC)
firstBlock := randomBlockInfo(t)
err := pmgr.RequestBlock(firstBlock)
assert.Nil(t, err)
secondBlock := randomBlockInfo(t)
err = pmgr.RequestBlock(secondBlock)
assert.Nil(t, err)
thirdBlock := randomBlockInfo(t)
err = pmgr.RequestBlock(thirdBlock)
assert.Nil(t, err)
// Since the peer manager did not get a MsgReceived
// in between the block requests
// a request should be sent to all peers
// This is only true, if peerBlockCacheLimit == 1
assert.Equal(t, 1, peerA.blockRequested)
assert.Equal(t, 1, peerB.blockRequested)
assert.Equal(t, 1, peerC.blockRequested)
// Since the peer manager still has not received a MsgReceived
// another call to request blocks, will add the request to the cache
// and return a nil err
fourthBlock := randomBlockInfo(t)
err = pmgr.RequestBlock(fourthBlock)
assert.Equal(t, nil, err)
assert.Equal(t, 1, pmgr.requestCache.cacheLen())
// If we tell the peer manager that we have received a block
// it will check the cache for any pending requests and send a block request if there are any.
// The request will go to the peer who sent back the block corresponding to the first hash
// since the other two peers are still busy with their block requests
peer := findPeerwithHash(t, pmgr, firstBlock.BlockHash)
err = pmgr.BlockMsgReceived(peer, firstBlock)
assert.Nil(t, err)
totalRequests := peerA.blockRequested + peerB.blockRequested + peerC.blockRequested
assert.Equal(t, 4, totalRequests)
// // cache should be empty now
assert.Equal(t, 0, pmgr.requestCache.cacheLen())
}
// The peer manager does not tell you what peer was sent a particular block request
// For testing purposes, the following function will find that peer
func findPeerwithHash(t *testing.T, pmgr *PeerMgr, blockHash util.Uint256) mPeer {
for peer, stats := range pmgr.peers {
_, err := stats.blockCache.findHash(blockHash)
if err == nil {
return peer
}
}
assert.Fail(t, "cannot find a peer with that hash")
return nil
}
func TestRequestHeaders(t *testing.T) {
pmgr := New()
peerA := &peer{nonce: 1}
peerB := &peer{nonce: 2}
peerC := &peer{nonce: 3}
pmgr.AddPeer(peerA)
pmgr.AddPeer(peerB)
pmgr.AddPeer(peerC)
err := pmgr.RequestHeaders(util.Uint256{})
assert.Nil(t, err)
err = pmgr.RequestHeaders(util.Uint256{})
assert.Nil(t, err)
err = pmgr.RequestHeaders(util.Uint256{})
assert.Nil(t, err)
// Since the peer manager did not get a MsgReceived
// in between the header requests
// a request should be sent to all peers
assert.Equal(t, 1, peerA.headersRequested)
assert.Equal(t, 1, peerB.headersRequested)
assert.Equal(t, 1, peerC.headersRequested)
// Since the peer manager still has not received a MsgReceived
// another call to request header, will return a NoAvailablePeerError
err = pmgr.RequestHeaders(util.Uint256{})
assert.Equal(t, ErrNoAvailablePeers, err)
// If we tell the peer manager that peerA has given us a block
// then send another BlockRequest. It will go to peerA
// since the other two peers are still busy with their
// block requests
err = pmgr.MsgReceived(peerA, command.Headers)
assert.Nil(t, err)
err = pmgr.RequestHeaders(util.Uint256{})
assert.Nil(t, err)
assert.Equal(t, 2, peerA.headersRequested)
assert.Equal(t, 1, peerB.headersRequested)
assert.Equal(t, 1, peerC.headersRequested)
}
func TestUnknownPeer(t *testing.T) {
pmgr := New()
unknownPeer := &peer{
disconnected: false,
quit: make(chan bool),
}
err := pmgr.MsgReceived(unknownPeer, command.Headers)
assert.Equal(t, true, unknownPeer.disconnected)
assert.Equal(t, ErrUnknownPeer, err)
}
func TestNotifyDisconnect(t *testing.T) {
pmgr := New()
peerA := &peer{
nonce: 1,
quit: make(chan bool),
}
pmgr.AddPeer(peerA)
if pmgr.Len() != 1 {
t.Fail()
}
peerA.Disconnect()
if pmgr.Len() != 0 {
t.Fail()
}
}

View file

@ -1,20 +0,0 @@
package pubsub
// EventType is an enum
// representing the types of messages we can subscribe to
type EventType int
const (
// NewBlock is called When blockchain connects a new block, it will emit an NewBlock Event
NewBlock EventType = iota
// BadBlock is called When blockchain declines a block, it will emit a new block event
BadBlock
// BadHeader is called When blockchain rejects a Header, it will emit this event
BadHeader
)
// Event represents a new Event that a subscriber can listen to
type Event struct {
Type EventType // E.g. event.NewBlock
data []byte // Raw information
}

View file

@ -1,21 +0,0 @@
package pubsub
// Publisher sends events to subscribers
type Publisher struct {
subs []Subscriber
}
// Send iterates over each subscriber and checks
// if they are interested in the Event
// By looking at their topics, if they are then
// the event is emitted to them
func (p *Publisher) Send(e Event) error {
for _, sub := range p.subs {
for _, topic := range sub.Topics() {
if e.Type == topic {
sub.Emit(e)
}
}
}
return nil
}

View file

@ -1,7 +0,0 @@
package pubsub
// Subscriber will listen for Events from publishers
type Subscriber interface {
Topics() []EventType
Emit(Event)
}

7
pkg/server/addrmgr.go Normal file
View file

@ -0,0 +1,7 @@
package server
// etAddress will return a viable address to connect to
// Currently it is hardcoded to be one neo node until address manager is implemented
func (s *Server) getAddress() (string, error) {
return "seed1.ngd.network:10333", nil
}

15
pkg/server/chain.go Normal file
View file

@ -0,0 +1,15 @@
package server
import (
"github.com/CityOfZion/neo-go/pkg/chain"
"github.com/CityOfZion/neo-go/pkg/database"
"github.com/CityOfZion/neo-go/pkg/wire/protocol"
)
func setupChain(db database.Database, net protocol.Magic) (*chain.Chain, error) {
chain, err := chain.New(db, net)
if err != nil {
return nil, err
}
return chain, nil
}

47
pkg/server/connmgr.go Normal file
View file

@ -0,0 +1,47 @@
package server
import (
"fmt"
"net"
"strconv"
"github.com/CityOfZion/neo-go/pkg/connmgr"
"github.com/CityOfZion/neo-go/pkg/peer"
iputils "github.com/CityOfZion/neo-go/pkg/wire/util/ip"
)
func setupConnManager(s *Server, port uint16) *connmgr.Connmgr {
cfg := connmgr.Config{
GetAddress: s.getAddress,
OnAccept: s.onAccept,
OnConnection: s.onConnection,
AddressPort: iputils.GetLocalIP().String() + ":" + strconv.FormatUint(uint64(port), 10),
}
return connmgr.New(cfg)
}
func (s *Server) onConnection(conn net.Conn, addr string) {
fmt.Println("We have connected successfully to: ", addr)
p := peer.NewPeer(conn, false, *s.peerCfg)
err := p.Run()
if err != nil {
fmt.Println("Error running peer" + err.Error())
return
}
s.pmg.AddPeer(p)
}
func (s *Server) onAccept(conn net.Conn) {
fmt.Println("A peer with address: ", conn.RemoteAddr().String(), "has connect to us")
p := peer.NewPeer(conn, true, *s.peerCfg)
err := p.Run()
if err != nil {
fmt.Println("Error running peer" + err.Error())
return
}
s.pmg.AddPeer(p)
}

14
pkg/server/database.go Normal file
View file

@ -0,0 +1,14 @@
package server
import (
"github.com/CityOfZion/neo-go/pkg/database"
"github.com/CityOfZion/neo-go/pkg/wire/protocol"
)
func setupDatabase(net protocol.Magic) (database.Database, error) {
db, err := database.New(net.String())
if err != nil {
return nil, err
}
return db, nil
}

23
pkg/server/peerconfig.go Normal file
View file

@ -0,0 +1,23 @@
package server
import (
"math/rand"
"github.com/CityOfZion/neo-go/pkg/peer"
"github.com/CityOfZion/neo-go/pkg/wire/protocol"
)
func setupPeerConfig(s *Server, port uint16, net protocol.Magic) *peer.LocalConfig {
return &peer.LocalConfig{
Net: net,
UserAgent: "NEO-GO",
Services: protocol.NodePeerService,
Nonce: rand.Uint32(),
ProtocolVer: 0,
Relay: false,
Port: port,
StartHeight: s.chain.CurrentHeight,
OnHeader: s.onHeader,
OnBlock: s.onBlock,
}
}

9
pkg/server/peermgr.go Normal file
View file

@ -0,0 +1,9 @@
package server
import (
"github.com/CityOfZion/neo-go/pkg/peermgr"
)
func setupPeerManager() *peermgr.PeerMgr {
return peermgr.New()
}

117
pkg/server/server.go Normal file
View file

@ -0,0 +1,117 @@
package server
import (
"fmt"
"github.com/CityOfZion/neo-go/pkg/peermgr"
"github.com/CityOfZion/neo-go/pkg/chain"
"github.com/CityOfZion/neo-go/pkg/connmgr"
"github.com/CityOfZion/neo-go/pkg/peer"
"github.com/CityOfZion/neo-go/pkg/syncmgr"
"github.com/CityOfZion/neo-go/pkg/database"
"github.com/CityOfZion/neo-go/pkg/wire/protocol"
)
// Server orchestrates all of the modules
type Server struct {
net protocol.Magic
stopCh chan error
// Modules
db database.Database
smg *syncmgr.Syncmgr
cmg *connmgr.Connmgr
pmg *peermgr.PeerMgr
chain *chain.Chain
peerCfg *peer.LocalConfig
}
//New creates a new server object for a particular network and sets up each module
func New(net protocol.Magic, port uint16) (*Server, error) {
s := &Server{
net: net,
stopCh: make(chan error, 0),
}
// Setup database
db, err := setupDatabase(net)
if err != nil {
return nil, err
}
s.db = db
// setup peermgr
peermgr := setupPeerManager()
s.pmg = peermgr
// Setup chain
chain, err := setupChain(db, net)
if err != nil {
return nil, err
}
s.chain = chain
// Setup sync manager
syncmgr, err := setupSyncManager(s)
if err != nil {
return nil, err
}
s.smg = syncmgr
// Setup connection manager
connmgr := setupConnManager(s, port)
s.cmg = connmgr
// Setup peer config
peerCfg := setupPeerConfig(s, port, net)
s.peerCfg = peerCfg
return s, nil
}
// Run starts the daemon by connecting to previously nodes or connectng to seed nodes.
// This should be called once all modules have been setup
func (s *Server) Run() error {
fmt.Println("Server is starting up")
// start the connmgr
err := s.cmg.Run()
if err != nil {
return err
}
// Attempt to connect to a peer
err = s.cmg.NewRequest()
if err != nil {
return err
}
// Request header to start synchronisation
bestHeader, err := s.chain.Db.GetLastHeader()
if err != nil {
return err
}
err = s.pmg.RequestHeaders(bestHeader.Hash)
if err != nil {
return err
}
fmt.Println("Server Successfully started")
return s.wait()
}
func (s *Server) wait() error {
err := <-s.stopCh
return err
}
// Stop stops the server
func (s *Server) Stop(err error) error {
fmt.Println("Server is shutting down")
s.stopCh <- err
return nil
}

110
pkg/server/syncmgr.go Normal file
View file

@ -0,0 +1,110 @@
package server
import (
"encoding/binary"
"github.com/CityOfZion/neo-go/pkg/peermgr"
"github.com/CityOfZion/neo-go/pkg/peer"
"github.com/CityOfZion/neo-go/pkg/syncmgr"
"github.com/CityOfZion/neo-go/pkg/wire/payload"
"github.com/CityOfZion/neo-go/pkg/wire/util"
)
func setupSyncManager(s *Server) (*syncmgr.Syncmgr, error) {
cfg := &syncmgr.Config{
ProcessBlock: s.processBlock,
ProcessHeaders: s.processHeaders,
RequestBlock: s.requestBlock,
RequestHeaders: s.requestHeaders,
GetNextBlockHash: s.getNextBlockHash,
AskForNewBlocks: s.askForNewBlocks,
FetchHeadersAgain: s.fetchHeadersAgain,
FetchBlockAgain: s.fetchBlockAgain,
}
// Add nextBlockIndex in syncmgr
lastBlock, err := s.chain.Db.GetLastBlock()
if err != nil {
return nil, err
}
nextBlockIndex := lastBlock.Index + 1
return syncmgr.New(cfg, nextBlockIndex), nil
}
func (s *Server) onHeader(peer *peer.Peer, hdrsMessage *payload.HeadersMessage) {
s.pmg.MsgReceived(peer, hdrsMessage.Command())
s.smg.OnHeader(peer, hdrsMessage)
}
func (s *Server) onBlock(peer *peer.Peer, blockMsg *payload.BlockMessage) {
s.pmg.BlockMsgReceived(peer, peermgr.BlockInfo{
BlockHash: blockMsg.Hash,
BlockIndex: blockMsg.Index,
})
s.smg.OnBlock(peer, blockMsg)
}
func (s *Server) processBlock(block payload.Block) error {
return s.chain.ProcessBlock(block)
}
func (s *Server) processHeaders(hdrs []*payload.BlockBase) error {
return s.chain.ProcessHeaders(hdrs)
}
func (s *Server) requestHeaders(hash util.Uint256) error {
return s.pmg.RequestHeaders(hash)
}
func (s *Server) requestBlock(hash util.Uint256, index uint32) error {
return s.pmg.RequestBlock(peermgr.BlockInfo{
BlockHash: hash,
BlockIndex: index,
})
}
// getNextBlockHash searches the database for the blockHash
// that is the height above our best block. The hash will be taken from a header.
func (s *Server) getNextBlockHash() (util.Uint256, error) {
bestBlock, err := s.chain.Db.GetLastBlock()
if err != nil {
// Panic!
// XXX: One alternative, is to get the network, erase the database and then start again from scratch.
// This should never happen. The latest block will always be atleast the genesis block
panic("could not get best block from database" + err.Error())
}
index := make([]byte, 4)
binary.BigEndian.PutUint32(index, bestBlock.Index+1)
hdr, err := s.chain.Db.GetHeaderFromHeight(index)
if err != nil {
return util.Uint256{}, err
}
return hdr.Hash, nil
}
func (s *Server) getBestBlockHash() (util.Uint256, error) {
return util.Uint256{}, nil
}
func (s *Server) askForNewBlocks() {
// send a getblocks message with the latest block saved
// when we receive something then send get data
}
func (s *Server) fetchHeadersAgain(util.Uint256) error {
return nil
}
func (s *Server) fetchBlockAgain(util.Uint256) error {
return nil
}

61
pkg/syncmgr/blockmode.go Normal file
View file

@ -0,0 +1,61 @@
package syncmgr
import (
"github.com/CityOfZion/neo-go/pkg/chain"
"github.com/CityOfZion/neo-go/pkg/wire/payload"
)
// blockModeOnBlock is called when the sync manager is block mode
// and receives a block.
func (s *Syncmgr) blockModeOnBlock(peer SyncPeer, block payload.Block) error {
// Check if it is a future block
// XXX: since we are storing blocks in memory, we do not want to store blocks
// from the tip
if block.Index > s.nextBlockIndex+2000 {
return nil
}
if block.Index > s.nextBlockIndex {
s.addToBlockPool(block)
return nil
}
// Process Block
err := s.processBlock(block)
if err != nil && err != chain.ErrBlockAlreadyExists {
return s.cfg.FetchBlockAgain(block.Hash)
}
// Check the block pool
err = s.checkPool()
if err != nil {
return err
}
// Check if blockhashReceived == the header hash from last get headers this node performed
// if not then increment and request next block
if s.headerHash != block.Hash {
nextHash, err := s.cfg.GetNextBlockHash()
if err != nil {
return err
}
return s.cfg.RequestBlock(nextHash, block.Index)
}
// If we are caught up then go into normal mode
diff := peer.Height() - block.Index
if diff <= cruiseHeight {
s.syncmode = normalMode
s.timer.Reset(blockTimer)
return nil
}
// If not then we go back into headersMode and request more headers.
s.syncmode = headersMode
return s.cfg.RequestHeaders(block.Hash)
}
func (s *Syncmgr) blockModeOnHeaders(peer SyncPeer, hdrs []*payload.BlockBase) error {
// We ignore headers when in this mode
return nil
}

57
pkg/syncmgr/blockpool.go Normal file
View file

@ -0,0 +1,57 @@
package syncmgr
import (
"sort"
"github.com/CityOfZion/neo-go/pkg/wire/payload"
)
func (s *Syncmgr) addToBlockPool(newBlock payload.Block) {
s.poolLock.Lock()
defer s.poolLock.Unlock()
for _, block := range s.blockPool {
if block.Index == newBlock.Index {
return
}
}
s.blockPool = append(s.blockPool, newBlock)
// sort slice using block index
sort.Slice(s.blockPool, func(i, j int) bool {
return s.blockPool[i].Index < s.blockPool[j].Index
})
}
func (s *Syncmgr) checkPool() error {
// Assuming that the blocks are sorted in order
var indexesToRemove = -1
s.poolLock.Lock()
defer func() {
// removes all elements before this index, including the element at this index
s.blockPool = s.blockPool[indexesToRemove+1:]
s.poolLock.Unlock()
}()
// loop iterates through the cache, processing any
// blocks that can be added to the chain
for i, block := range s.blockPool {
if s.nextBlockIndex != block.Index {
break
}
// Save this block and save the indice location so we can remove it, when we defer
err := s.processBlock(block)
if err != nil {
return err
}
indexesToRemove = i
}
return nil
}

View file

@ -0,0 +1,42 @@
package syncmgr
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestAddBlockPoolFlush(t *testing.T) {
syncmgr, _ := setupSyncMgr(blockMode, 10)
blockMessage := randomBlockMessage(t, 11)
peer := &mockPeer{
height: 100,
}
// Since the block has Index 11 and the sync manager needs the block with index 10
// This block will be added to the blockPool
err := syncmgr.OnBlock(peer, blockMessage)
assert.Nil(t, err)
assert.Equal(t, 1, len(syncmgr.blockPool))
// The sync manager is still looking for the block at height 10
// Since this block is at height 12, it will be added to the block pool
blockMessage = randomBlockMessage(t, 12)
err = syncmgr.OnBlock(peer, blockMessage)
assert.Nil(t, err)
assert.Equal(t, 2, len(syncmgr.blockPool))
// This is the block that the sync manager was waiting for
// It should process this block, the check the pool for the next set of blocks
blockMessage = randomBlockMessage(t, 10)
err = syncmgr.OnBlock(peer, blockMessage)
assert.Nil(t, err)
assert.Equal(t, 0, len(syncmgr.blockPool))
// Since we processed 3 blocks and the sync manager started
//looking for block with index 10. The syncmananger should be looking for
// the block with index 13
assert.Equal(t, uint32(13), syncmgr.nextBlockIndex)
}

44
pkg/syncmgr/config.go Normal file
View file

@ -0,0 +1,44 @@
package syncmgr
import (
"github.com/CityOfZion/neo-go/pkg/wire/payload"
"github.com/CityOfZion/neo-go/pkg/wire/util"
)
// Config is the configuration file for the sync manager
type Config struct {
// Chain functions
ProcessBlock func(block payload.Block) error
ProcessHeaders func(hdrs []*payload.BlockBase) error
// RequestHeaders will send a getHeaders request
// with the hash passed in as a parameter
RequestHeaders func(hash util.Uint256) error
//RequestBlock will send a getdata request for the block
// with the hash passed as a parameter
RequestBlock func(hash util.Uint256, index uint32) error
// GetNextBlockHash returns the block hash of the header infront of thr block
// at the tip of this nodes chain. This assumes that the node is not in sync
GetNextBlockHash func() (util.Uint256, error)
// AskForNewBlocks will send out a message to the network
// asking for new blocks
AskForNewBlocks func()
// FetchHeadersAgain is called when a peer has provided headers that have not
// validated properly. We pass in the hash of the first header
FetchHeadersAgain func(util.Uint256) error
// FetchHeadersAgain is called when a peer has provided a block that has not
// validated properly. We pass in the hash of the block
FetchBlockAgain func(util.Uint256) error
}
// SyncPeer represents a peer on the network
// that this node can sync with
type SyncPeer interface {
Height() uint32
}

42
pkg/syncmgr/headermode.go Normal file
View file

@ -0,0 +1,42 @@
package syncmgr
import (
"github.com/CityOfZion/neo-go/pkg/chain"
"github.com/CityOfZion/neo-go/pkg/wire/payload"
)
// headersModeOnHeaders is called when the sync manager is headers mode
// and receives a header.
func (s *Syncmgr) headersModeOnHeaders(peer SyncPeer, hdrs []*payload.BlockBase) error {
// If we are in Headers mode, then we just need to process the headers
// Note: For the un-optimised version, we move straight to blocksOnly mode
firstHash := hdrs[0].Hash
firstHdrIndex := hdrs[0].Index
err := s.cfg.ProcessHeaders(hdrs)
if err == nil {
// Update syncmgr last header
s.headerHash = hdrs[len(hdrs)-1].Hash
s.syncmode = blockMode
return s.cfg.RequestBlock(firstHash, firstHdrIndex)
}
// Check whether it is a validation error, or a database error
if _, ok := err.(*chain.ValidationError); ok {
// If we get a validation error we re-request the headers
// the method will automatically fetch from a different peer
// XXX: Add increment banScore for this peer
return s.cfg.FetchHeadersAgain(firstHash)
}
// This means it is a database error. We have no way to recover from this.
panic(err.Error())
}
// headersModeOnBlock is called when the sync manager is headers mode
// and receives a block.
func (s *Syncmgr) headersModeOnBlock(peer SyncPeer, block payload.Block) error {
// While in headers mode, ignore any blocks received
return nil
}

View file

@ -0,0 +1,113 @@
package syncmgr
import (
"crypto/rand"
"testing"
"github.com/stretchr/testify/assert"
"github.com/CityOfZion/neo-go/pkg/wire/payload"
"github.com/CityOfZion/neo-go/pkg/wire/util"
)
type syncTestHelper struct {
blocksProcessed int
headersProcessed int
newBlockRequest int
headersFetchRequest int
blockFetchRequest int
err error
}
func (s *syncTestHelper) ProcessBlock(msg payload.Block) error {
s.blocksProcessed++
return s.err
}
func (s *syncTestHelper) ProcessHeaders(hdrs []*payload.BlockBase) error {
s.headersProcessed = s.headersProcessed + len(hdrs)
return s.err
}
func (s *syncTestHelper) GetNextBlockHash() (util.Uint256, error) {
return util.Uint256{}, s.err
}
func (s *syncTestHelper) AskForNewBlocks() {
s.newBlockRequest++
}
func (s *syncTestHelper) FetchHeadersAgain(util.Uint256) error {
s.headersFetchRequest++
return s.err
}
func (s *syncTestHelper) FetchBlockAgain(util.Uint256) error {
s.blockFetchRequest++
return s.err
}
func (s *syncTestHelper) RequestBlock(util.Uint256, uint32) error {
s.blockFetchRequest++
return s.err
}
func (s *syncTestHelper) RequestHeaders(util.Uint256) error {
s.headersFetchRequest++
return s.err
}
type mockPeer struct {
height uint32
}
func (p *mockPeer) Height() uint32 { return p.height }
func randomHeadersMessage(t *testing.T, num int) *payload.HeadersMessage {
var hdrs []*payload.BlockBase
for i := 0; i < num; i++ {
hash := randomUint256(t)
hdr := &payload.BlockBase{Hash: hash}
hdrs = append(hdrs, hdr)
}
hdrsMsg, err := payload.NewHeadersMessage()
assert.Nil(t, err)
hdrsMsg.Headers = hdrs
return hdrsMsg
}
func randomUint256(t *testing.T) util.Uint256 {
hash := make([]byte, 32)
_, err := rand.Read(hash)
assert.Nil(t, err)
u, err := util.Uint256DecodeBytes(hash)
assert.Nil(t, err)
return u
}
func setupSyncMgr(mode mode, nextBlockIndex uint32) (*Syncmgr, *syncTestHelper) {
helper := &syncTestHelper{}
cfg := &Config{
ProcessBlock: helper.ProcessBlock,
ProcessHeaders: helper.ProcessHeaders,
GetNextBlockHash: helper.GetNextBlockHash,
AskForNewBlocks: helper.AskForNewBlocks,
FetchHeadersAgain: helper.FetchHeadersAgain,
FetchBlockAgain: helper.FetchBlockAgain,
RequestBlock: helper.RequestBlock,
RequestHeaders: helper.RequestHeaders,
}
syncmgr := New(cfg, nextBlockIndex)
syncmgr.syncmode = mode
return syncmgr, helper
}

60
pkg/syncmgr/normalmode.go Normal file
View file

@ -0,0 +1,60 @@
package syncmgr
import (
"github.com/CityOfZion/neo-go/pkg/wire/payload"
)
func (s *Syncmgr) normalModeOnHeaders(peer SyncPeer, hdrs []*payload.BlockBase) error {
// If in normal mode, first process the headers
err := s.cfg.ProcessHeaders(hdrs)
if err != nil {
// If something went wrong with processing the headers
// Ask another peer for the headers.
//XXX: Increment banscore for this peer
return s.cfg.FetchHeadersAgain(hdrs[0].Hash)
}
lenHeaders := len(hdrs)
firstHash := hdrs[0].Hash
firstHdrIndex := hdrs[0].Index
lastHash := hdrs[lenHeaders-1].Hash
// Update syncmgr latest header
s.headerHash = lastHash
// If there are 2k headers, then ask for more headers and switch back to headers mode.
if lenHeaders == 2000 {
s.syncmode = headersMode
return s.cfg.RequestHeaders(lastHash)
}
// Ask for the corresponding block iff there is < 2k headers
// then switch to blocksMode
// Bounds state that len > 1 && len!= 2000 & maxHeadersInMessage == 2000
// This means that we have less than 2k headers
s.syncmode = blockMode
return s.cfg.RequestBlock(firstHash, firstHdrIndex)
}
// normalModeOnBlock is called when the sync manager is normal mode
// and receives a block.
func (s *Syncmgr) normalModeOnBlock(peer SyncPeer, block payload.Block) error {
// stop the timer that periodically asks for blocks
s.timer.Stop()
// process block
err := s.processBlock(block)
if err != nil {
s.timer.Reset(blockTimer)
return s.cfg.FetchBlockAgain(block.Hash)
}
diff := peer.Height() - block.Index
if diff > trailingHeight {
s.syncmode = headersMode
return s.cfg.RequestHeaders(block.Hash)
}
s.timer.Reset(blockTimer)
return nil
}

152
pkg/syncmgr/syncmgr.go Normal file
View file

@ -0,0 +1,152 @@
package syncmgr
import (
"fmt"
"sync"
"time"
"github.com/CityOfZion/neo-go/pkg/wire/payload"
"github.com/CityOfZion/neo-go/pkg/wire/util"
)
type mode uint8
// Note: this is the unoptimised version without parallel sync
// The algorithm for the unoptimsied version is simple:
// Download 2000 headers, then download the blocks for those headers
// Once those blocks are downloaded, we repeat the process again
// Until we are nomore than one block behind the tip.
// Once this happens, we switch into normal mode.
//In normal mode, we have a timer on for X seconds and ask nodes for blocks and also to doublecheck
// if we are behind once the timer runs out.
// The timer restarts whenever we receive a block.
// The parameter X should be approximately the time it takes the network to reach consensus
//blockTimer approximates to how long it takes to reach consensus and propagate
// a block in the network. Once a node has synchronised with the network, he will
// ask the network for a newblock every blockTimer
const blockTimer = 20 * time.Second
// trailingHeight indicates how many blocks the node has to be behind by
// before he switches to headersMode.
const trailingHeight = 100
// indicates how many blocks the node has to be behind by
// before he switches to normalMode and fetches blocks every X seconds.
const cruiseHeight = 0
const (
headersMode mode = 1
blockMode mode = 2
normalMode mode = 3
)
//Syncmgr keeps the node in sync with the rest of the network
type Syncmgr struct {
syncmode mode
cfg *Config
timer *time.Timer
// headerHash is the hash of the last header in the last OnHeaders message that we received.
// When receiving blocks, we can use this to determine whether the node has downloaded
// all of the blocks for the last headers messages
headerHash util.Uint256
poolLock sync.Mutex
blockPool []payload.Block
nextBlockIndex uint32
}
// New creates a new sync manager
func New(cfg *Config, nextBlockIndex uint32) *Syncmgr {
newBlockTimer := time.AfterFunc(blockTimer, func() {
cfg.AskForNewBlocks()
})
newBlockTimer.Stop()
return &Syncmgr{
syncmode: headersMode,
cfg: cfg,
timer: newBlockTimer,
nextBlockIndex: nextBlockIndex,
}
}
// OnHeader is called when the node receives a headers message
func (s *Syncmgr) OnHeader(peer SyncPeer, msg *payload.HeadersMessage) error {
// XXX(Optimisation): First check if we actually need these headers
// Check the last header in msg and then check what our latest header that was saved is
// If our latest header is above the lastHeader, then we do not save it
// We could also have that our latest header is above only some of the headers.
// In this case, we should remove the headers that we already have
if len(msg.Headers) == 0 {
// XXX: Increment banScore for this peer, for sending empty headers message
return nil
}
var err error
switch s.syncmode {
case headersMode:
err = s.headersModeOnHeaders(peer, msg.Headers)
case blockMode:
err = s.blockModeOnHeaders(peer, msg.Headers)
case normalMode:
err = s.normalModeOnHeaders(peer, msg.Headers)
default:
err = s.headersModeOnHeaders(peer, msg.Headers)
}
// XXX(Kev):The only meaningful error here would be if the peer
// we re-requested blocks from failed. In the next iteration, this will be handled
// by the peer manager, who will only return an error, if we are connected to no peers.
// Upon re-alising this, the node will then send out GetAddresses to the network and
// syncing will be resumed, once we find peers to connect to.
hdr := msg.Headers[len(msg.Headers)-1]
fmt.Printf("Finished processing headers. LastHash in set was: %s\n ", hdr.Hash.ReverseString())
return err
}
// OnBlock is called when the node receives a block
func (s *Syncmgr) OnBlock(peer SyncPeer, msg *payload.BlockMessage) error {
fmt.Printf("Block received with height %d\n", msg.Block.Index)
var err error
switch s.syncmode {
case headersMode:
err = s.headersModeOnBlock(peer, msg.Block)
case blockMode:
err = s.blockModeOnBlock(peer, msg.Block)
case normalMode:
err = s.normalModeOnBlock(peer, msg.Block)
default:
err = s.headersModeOnBlock(peer, msg.Block)
}
fmt.Printf("Processed Block with height %d\n", msg.Block.Index)
return err
}
//IsCurrent returns true if the node is currently
// synced up with the network
func (s *Syncmgr) IsCurrent() bool {
return s.syncmode == normalMode
}
func (s *Syncmgr) processBlock(block payload.Block) error {
err := s.cfg.ProcessBlock(block)
if err != nil {
return err
}
s.nextBlockIndex++
return nil
}

View file

@ -0,0 +1,97 @@
package syncmgr
import (
"testing"
"github.com/CityOfZion/neo-go/pkg/chain"
"github.com/CityOfZion/neo-go/pkg/wire/payload"
"github.com/stretchr/testify/assert"
)
func TestHeadersModeOnBlock(t *testing.T) {
syncmgr, helper := setupSyncMgr(headersMode, 0)
syncmgr.OnBlock(&mockPeer{}, randomBlockMessage(t, 0))
// In headerMode, we do nothing
assert.Equal(t, 0, helper.blocksProcessed)
}
func TestBlockModeOnBlock(t *testing.T) {
syncmgr, helper := setupSyncMgr(blockMode, 0)
syncmgr.OnBlock(&mockPeer{}, randomBlockMessage(t, 0))
// When a block is received in blockMode, it is processed
assert.Equal(t, 1, helper.blocksProcessed)
}
func TestNormalModeOnBlock(t *testing.T) {
syncmgr, helper := setupSyncMgr(normalMode, 0)
syncmgr.OnBlock(&mockPeer{}, randomBlockMessage(t, 0))
// When a block is received in normal, it is processed
assert.Equal(t, 1, helper.blocksProcessed)
}
func TestBlockModeToNormalMode(t *testing.T) {
syncmgr, _ := setupSyncMgr(blockMode, 100)
peer := &mockPeer{
height: 100,
}
blkMessage := randomBlockMessage(t, 100)
syncmgr.OnBlock(peer, blkMessage)
// We should switch to normal mode, since the block
//we received is close to the height of the peer. See cruiseHeight
assert.Equal(t, normalMode, syncmgr.syncmode)
}
func TestBlockModeStayInBlockMode(t *testing.T) {
syncmgr, _ := setupSyncMgr(blockMode, 0)
// We need our latest know hash to not be equal to the hash
// of the block we received, to stay in blockmode
syncmgr.headerHash = randomUint256(t)
peer := &mockPeer{
height: 2000,
}
blkMessage := randomBlockMessage(t, 100)
syncmgr.OnBlock(peer, blkMessage)
// We should stay in block mode, since the block we received is
// still quite far behind the peers height
assert.Equal(t, blockMode, syncmgr.syncmode)
}
func TestBlockModeAlreadyExistsErr(t *testing.T) {
syncmgr, helper := setupSyncMgr(blockMode, 100)
helper.err = chain.ErrBlockAlreadyExists
syncmgr.OnBlock(&mockPeer{}, randomBlockMessage(t, 100))
assert.Equal(t, 0, helper.blockFetchRequest)
// If we have a block already exists in blockmode, then we
// switch back to headers mode.
assert.Equal(t, headersMode, syncmgr.syncmode)
}
func randomBlockMessage(t *testing.T, height uint32) *payload.BlockMessage {
blockMessage, err := payload.NewBlockMessage()
blockMessage.BlockBase.Index = height
assert.Nil(t, err)
return blockMessage
}

View file

@ -0,0 +1,117 @@
package syncmgr
import (
"testing"
"github.com/CityOfZion/neo-go/pkg/chain"
"github.com/stretchr/testify/assert"
"github.com/CityOfZion/neo-go/pkg/wire/util"
)
func TestHeadersModeOnHeaders(t *testing.T) {
syncmgr, helper := setupSyncMgr(headersMode, 0)
syncmgr.OnHeader(&mockPeer{}, randomHeadersMessage(t, 0))
// Since there were no headers, we should have exited early and processed nothing
assert.Equal(t, 0, helper.headersProcessed)
// ProcessHeaders should have been called once to process all 100 headers
syncmgr.OnHeader(&mockPeer{}, randomHeadersMessage(t, 100))
assert.Equal(t, 100, helper.headersProcessed)
// Mode should now be blockMode
assert.Equal(t, blockMode, syncmgr.syncmode)
}
func TestBlockModeOnHeaders(t *testing.T) {
syncmgr, helper := setupSyncMgr(blockMode, 0)
// If we receive a header in blockmode, no headers will be processed
syncmgr.OnHeader(&mockPeer{}, randomHeadersMessage(t, 100))
assert.Equal(t, 0, helper.headersProcessed)
}
func TestNormalModeOnHeadersMaxHeaders(t *testing.T) {
syncmgr, helper := setupSyncMgr(normalMode, 0)
// If we receive a header in normalmode, headers will be processed
syncmgr.OnHeader(&mockPeer{}, randomHeadersMessage(t, 2000))
assert.Equal(t, 2000, helper.headersProcessed)
// Mode should now be headersMode since we received 2000 headers
assert.Equal(t, headersMode, syncmgr.syncmode)
}
// This differs from the previous function in that
//we did not receive the max amount of headers
func TestNormalModeOnHeaders(t *testing.T) {
syncmgr, helper := setupSyncMgr(normalMode, 0)
// If we receive a header in normalmode, headers will be processed
syncmgr.OnHeader(&mockPeer{}, randomHeadersMessage(t, 200))
assert.Equal(t, 200, helper.headersProcessed)
// Because we did not receive 2000 headers, we switch to blockMode
assert.Equal(t, blockMode, syncmgr.syncmode)
}
func TestLastHeaderUpdates(t *testing.T) {
syncmgr, _ := setupSyncMgr(headersMode, 0)
hdrsMessage := randomHeadersMessage(t, 200)
hdrs := hdrsMessage.Headers
lastHeader := hdrs[len(hdrs)-1]
syncmgr.OnHeader(&mockPeer{}, hdrsMessage)
// Headers are processed in headersMode
// Last header should be updated
assert.True(t, syncmgr.headerHash.Equals(lastHeader.Hash))
// Change mode to blockMode and reset lastHeader
syncmgr.syncmode = blockMode
syncmgr.headerHash = util.Uint256{}
syncmgr.OnHeader(&mockPeer{}, hdrsMessage)
// header should not be changed
assert.False(t, syncmgr.headerHash.Equals(lastHeader.Hash))
// Change mode to normalMode and reset lastHeader
syncmgr.syncmode = normalMode
syncmgr.headerHash = util.Uint256{}
syncmgr.OnHeader(&mockPeer{}, hdrsMessage)
// headers are processed in normalMode
// hash should be updated
assert.True(t, syncmgr.headerHash.Equals(lastHeader.Hash))
}
func TestHeadersModeOnHeadersErr(t *testing.T) {
syncmgr, helper := setupSyncMgr(headersMode, 0)
helper.err = &chain.ValidationError{}
syncmgr.OnHeader(&mockPeer{}, randomHeadersMessage(t, 200))
// On a validation error, we should request for another peer
// to send us these headers
assert.Equal(t, 1, helper.headersFetchRequest)
}
func TestNormalModeOnHeadersErr(t *testing.T) {
syncmgr, helper := setupSyncMgr(normalMode, 0)
helper.err = &chain.ValidationError{}
syncmgr.OnHeader(&mockPeer{}, randomHeadersMessage(t, 200))
// On a validation error, we should request for another peer
// to send us these headers
assert.Equal(t, 1, helper.headersFetchRequest)
}

View file

@ -296,7 +296,6 @@ func BoolAnd(op stack.Instruction, ctx *stack.Context, istack *stack.Invocation,
return FAULT, err return FAULT, err
} }
res := bool1.And(bool2) res := bool1.And(bool2)
ctx.Estack.Push(res) ctx.Estack.Push(res)
return NONE, nil return NONE, nil

File diff suppressed because one or more lines are too long

View file

@ -50,7 +50,7 @@ func TestAddAndEncodeHeaders(t *testing.T) {
err := msgHeaders.Headers[0].createHash() err := msgHeaders.Headers[0].createHash()
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
// Hash being correct, automatically verifies that the fields are encoded properly // Hash being correct, automatically verifies that the fields are encoded properly
assert.Equal(t, "f3c4ec44c07eccbda974f1ee34bc6654ab6d3f22cd89c2e5c593a16d6cc7e6e8", msgHeaders.Headers[0].Hash.String()) assert.Equal(t, "f3c4ec44c07eccbda974f1ee34bc6654ab6d3f22cd89c2e5c593a16d6cc7e6e8", msgHeaders.Headers[0].Hash.ReverseString())
} }
@ -70,7 +70,7 @@ func TestEncodeDecode(t *testing.T) {
header := headerMsg.Headers[0] header := headerMsg.Headers[0]
err = header.createHash() err = header.createHash()
assert.Equal(t, "f3c4ec44c07eccbda974f1ee34bc6654ab6d3f22cd89c2e5c593a16d6cc7e6e8", header.Hash.String()) assert.Equal(t, "f3c4ec44c07eccbda974f1ee34bc6654ab6d3f22cd89c2e5c593a16d6cc7e6e8", header.Hash.ReverseString())
buf := new(bytes.Buffer) buf := new(bytes.Buffer)

View file

@ -24,14 +24,13 @@ func (a *Attribute) Encode(bw *util.BinWriter) {
} }
bw.Write(uint8(a.Usage)) bw.Write(uint8(a.Usage))
if a.Usage == DescriptionURL || a.Usage == Vote || (a.Usage >= Hash1 && a.Usage <= Hash15) { if a.Usage == ContractHash || a.Usage == Vote || (a.Usage >= Hash1 && a.Usage <= Hash15) {
bw.Write(a.Data[:32]) bw.Write(a.Data[:32])
} else if a.Usage == Script {
bw.Write(a.Data[:20])
} else if a.Usage == ECDH02 || a.Usage == ECDH03 { } else if a.Usage == ECDH02 || a.Usage == ECDH03 {
bw.Write(a.Data[1:33]) bw.Write(a.Data[1:33])
} else if a.Usage == CertURL || a.Usage == DescriptionURL || a.Usage == Description || a.Usage >= Remark { } else if a.Usage == Script {
bw.Write(a.Data[:20])
} else if a.Usage == DescriptionURL || a.Usage == Description || a.Usage >= Remark {
bw.VarUint(uint64(len(a.Data))) bw.VarUint(uint64(len(a.Data)))
bw.Write(a.Data) bw.Write(a.Data)
} else { } else {
@ -43,17 +42,16 @@ func (a *Attribute) Encode(bw *util.BinWriter) {
// Decode decodes the binary reader into an Attribute object // Decode decodes the binary reader into an Attribute object
func (a *Attribute) Decode(br *util.BinReader) { func (a *Attribute) Decode(br *util.BinReader) {
br.Read(&a.Usage) br.Read(&a.Usage)
if a.Usage == DescriptionURL || a.Usage == Vote || a.Usage >= Hash1 && a.Usage <= Hash15 { if a.Usage == ContractHash || a.Usage == Vote || a.Usage >= Hash1 && a.Usage <= Hash15 {
a.Data = make([]byte, 32) a.Data = make([]byte, 32)
br.Read(&a.Data) br.Read(&a.Data)
} else if a.Usage == Script {
a.Data = make([]byte, 20)
br.Read(&a.Data)
} else if a.Usage == ECDH02 || a.Usage == ECDH03 { } else if a.Usage == ECDH02 || a.Usage == ECDH03 {
a.Data = make([]byte, 32) a.Data = make([]byte, 32)
br.Read(&a.Data) br.Read(&a.Data)
} else if a.Usage == CertURL || a.Usage == DescriptionURL || a.Usage == Description || a.Usage >= Remark { } else if a.Usage == Script {
a.Data = make([]byte, 20)
br.Read(&a.Data)
} else if a.Usage == DescriptionURL || a.Usage == Description || a.Usage >= Remark {
lenData := br.VarUint() lenData := br.VarUint()
a.Data = make([]byte, lenData) a.Data = make([]byte, lenData)
br.Read(&a.Data) br.Read(&a.Data)

View file

@ -18,11 +18,8 @@ type decodeExclusiveFields func(br *util.BinReader)
type Transactioner interface { type Transactioner interface {
Encode(w io.Writer) error Encode(w io.Writer) error
Decode(r io.Reader) error Decode(r io.Reader) error
BaseTx() *Base
ID() (util.Uint256, error) ID() (util.Uint256, error)
Bytes() []byte
UTXOs() []*Output
TXOs() []*Input
Witness() []*Witness
} }
// Base transaction is the template for all other transactions // Base transaction is the template for all other transactions
@ -199,17 +196,7 @@ func (b *Base) Bytes() []byte {
return buf.Bytes() return buf.Bytes()
} }
// UTXOs returns the outputs in the tx // BaseTx returns the Base object in a transaction
func (b *Base) UTXOs() []*Output { func (b *Base) BaseTx() *Base {
return b.Outputs return b
}
// TXOs returns the inputs in the tx
func (b *Base) TXOs() []*Input {
return b.Inputs
}
// Witness returns the witnesses in the tx
func (b *Base) Witness() []*Witness {
return b.Witnesses
} }

View file

@ -26,9 +26,9 @@ func TestEncodeDecodeClaim(t *testing.T) {
assert.Equal(t, 1, int(len(c.Claims))) assert.Equal(t, 1, int(len(c.Claims)))
claim := c.Claims[0] claim := c.Claims[0]
assert.Equal(t, "497037a4c5e0a9ea1721e06f9d5e9aec183d11f2824ece93285729370f3a1baf", claim.PrevHash.String()) assert.Equal(t, "497037a4c5e0a9ea1721e06f9d5e9aec183d11f2824ece93285729370f3a1baf", claim.PrevHash.ReverseString())
assert.Equal(t, uint16(0), claim.PrevIndex) assert.Equal(t, uint16(0), claim.PrevIndex)
assert.Equal(t, "abf142faf539c340e42722b5b34b505cf4fd73185fed775784e37c2c5ef1b866", c.Hash.String()) assert.Equal(t, "abf142faf539c340e42722b5b34b505cf4fd73185fed775784e37c2c5ef1b866", c.Hash.ReverseString())
// Encode // Encode
buf := new(bytes.Buffer) buf := new(bytes.Buffer)

View file

@ -27,12 +27,12 @@ func TestEncodeDecodeContract(t *testing.T) {
input := c.Inputs[0] input := c.Inputs[0]
assert.Equal(t, "eec17cc828d6ede932b57e4eaf79c2591151096a7825435cd67f498f9fa98d88", input.PrevHash.String()) assert.Equal(t, "eec17cc828d6ede932b57e4eaf79c2591151096a7825435cd67f498f9fa98d88", input.PrevHash.ReverseString())
assert.Equal(t, 0, int(input.PrevIndex)) assert.Equal(t, 0, int(input.PrevIndex))
assert.Equal(t, int64(70600000000), c.Outputs[0].Amount) assert.Equal(t, int64(70600000000), c.Outputs[0].Amount)
assert.Equal(t, "c56f33fc6ecfcd0c225c4ab356fee59390af8560be0e930faebe74a6daff7c9b", c.Outputs[0].AssetID.String()) assert.Equal(t, "c56f33fc6ecfcd0c225c4ab356fee59390af8560be0e930faebe74a6daff7c9b", c.Outputs[0].AssetID.ReverseString())
assert.Equal(t, "a8666b4830229d6a1a9b80f6088059191c122d2b", c.Outputs[0].ScriptHash.String()) assert.Equal(t, "a8666b4830229d6a1a9b80f6088059191c122d2b", c.Outputs[0].ScriptHash.String())
assert.Equal(t, "bdf6cc3b9af12a7565bda80933a75ee8cef1bc771d0d58effc08e4c8b436da79", c.Hash.String()) assert.Equal(t, "bdf6cc3b9af12a7565bda80933a75ee8cef1bc771d0d58effc08e4c8b436da79", c.Hash.ReverseString())
// Encode // Encode
buf := new(bytes.Buffer) buf := new(bytes.Buffer)

View file

@ -26,5 +26,5 @@ func TestEncodeDecodeEnrollment(t *testing.T) {
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, rawtx, hex.EncodeToString(buf.Bytes())) assert.Equal(t, rawtx, hex.EncodeToString(buf.Bytes()))
assert.Equal(t, "988832f693785dcbcb8d5a0e9d5d22002adcbfb1eb6bbeebf8c494fff580e147", enroll.Hash.String()) assert.Equal(t, "988832f693785dcbcb8d5a0e9d5d22002adcbfb1eb6bbeebf8c494fff580e147", enroll.Hash.ReverseString())
} }

View file

@ -34,7 +34,7 @@ func TestEncodeDecodeInvoc(t *testing.T) {
assert.Equal(t, "31363a30373a3032203a2030333366616431392d643638322d343035382d626437662d313563393331323434336538", hex.EncodeToString(attr2.Data)) assert.Equal(t, "31363a30373a3032203a2030333366616431392d643638322d343035382d626437662d313563393331323434336538", hex.EncodeToString(attr2.Data))
assert.Equal(t, "050034e23004141ad842821c7341d5a32b17d7177a1750d30014ca14628c9e5bc6a9346ca6bcdf050ceabdeb2bdc774953c1087472616e736665726703e1df72015bdef1a1b9567d4700635f23b1f406f1", hex.EncodeToString(i.Script)) assert.Equal(t, "050034e23004141ad842821c7341d5a32b17d7177a1750d30014ca14628c9e5bc6a9346ca6bcdf050ceabdeb2bdc774953c1087472616e736665726703e1df72015bdef1a1b9567d4700635f23b1f406f1", hex.EncodeToString(i.Script))
assert.Equal(t, "b2a22cd9dd7636ae23e25576866cd1d9e2f3d85a85e80874441f085cd60006d1", i.Hash.String()) assert.Equal(t, "b2a22cd9dd7636ae23e25576866cd1d9e2f3d85a85e80874441f085cd60006d1", i.Hash.ReverseString())
// Encode // Encode
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
@ -42,3 +42,37 @@ func TestEncodeDecodeInvoc(t *testing.T) {
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, rawtxBytes, buf.Bytes()) assert.Equal(t, rawtxBytes, buf.Bytes())
} }
func TestEncodeDecodeInvocAttributes(t *testing.T) {
// taken from mainnet cb0b5edc7e87b3b1bd9e029112fd3ce17c16d3de20c43ca1c0c26f3add578ecb
rawtx := "d1015308005b950f5e010000140000000000000000000000000000000000000000141a1e29d6232d2148e1e71e30249835ea41eb7a3d53c1087472616e7366657267fb1c540417067c270dee32f21023aa8b9b71abce000000000000000002201a1e29d6232d2148e1e71e30249835ea41eb7a3d8110f9f504da6334935a2db42b18296d88700000014140461370f6847c4abbdddff54a3e1337e453ecc8133c882ec5b9aabcf0f47dafd3432d47e449f4efc77447ef03519b7808c450a998cca3ecc10e6536ed9db862ba23210285264b6f349f0fe86e9bb3044fde8f705b016593cf88cd5e8a802b78c7d2c950ac"
rawtxBytes, _ := hex.DecodeString(rawtx)
i := NewInvocation(30)
r := bytes.NewReader(rawtxBytes)
err := i.Decode(r)
assert.Equal(t, nil, err)
assert.Equal(t, types.Invocation, i.Type)
assert.Equal(t, 1, int(i.Version))
assert.Equal(t, 2, len(i.Attributes))
assert.Equal(t, Script, i.Attributes[0].Usage)
assert.Equal(t, "1a1e29d6232d2148e1e71e30249835ea41eb7a3d", hex.EncodeToString(i.Attributes[0].Data))
assert.Equal(t, DescriptionURL, i.Attributes[1].Usage)
assert.Equal(t, "f9f504da6334935a2db42b18296d8870", hex.EncodeToString(i.Attributes[1].Data))
assert.Equal(t, "08005b950f5e010000140000000000000000000000000000000000000000141a1e29d6232d2148e1e71e30249835ea41eb7a3d53c1087472616e7366657267fb1c540417067c270dee32f21023aa8b9b71abce", hex.EncodeToString(i.Script))
assert.Equal(t, "cb0b5edc7e87b3b1bd9e029112fd3ce17c16d3de20c43ca1c0c26f3add578ecb", i.Hash.ReverseString())
// Encode
buf := new(bytes.Buffer)
err = i.Encode(buf)
assert.Equal(t, nil, err)
assert.Equal(t, rawtxBytes, buf.Bytes())
}

View file

@ -24,7 +24,7 @@ func TestEncodeDecodeMiner(t *testing.T) {
assert.Equal(t, types.Miner, m.Type) assert.Equal(t, types.Miner, m.Type)
assert.Equal(t, uint32(571397116), m.Nonce) assert.Equal(t, uint32(571397116), m.Nonce)
assert.Equal(t, "a1f219dc6be4c35eca172e65e02d4591045220221b1543f1a4b67b9e9442c264", m.Hash.String()) assert.Equal(t, "a1f219dc6be4c35eca172e65e02d4591045220221b1543f1a4b67b9e9442c264", m.Hash.ReverseString())
// Encode // Encode
buf := new(bytes.Buffer) buf := new(bytes.Buffer)

View file

@ -28,6 +28,6 @@ func TestEncodeDecodePublish(t *testing.T) {
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, rawtx, hex.EncodeToString(buf.Bytes())) assert.Equal(t, rawtx, hex.EncodeToString(buf.Bytes()))
assert.Equal(t, "5467a1fc8723ceffa8e5ee59399b02eea1df6fbaa53768c6704b90b960d223fa", publ.Hash.String()) assert.Equal(t, "5467a1fc8723ceffa8e5ee59399b02eea1df6fbaa53768c6704b90b960d223fa", publ.Hash.ReverseString())
} }

View file

@ -28,7 +28,7 @@ func TestEncodeDecodeRegister(t *testing.T) {
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, rawtx, hex.EncodeToString(buf.Bytes())) assert.Equal(t, rawtx, hex.EncodeToString(buf.Bytes()))
assert.Equal(t, "0c092117b4ba47b81001712425e6e7f760a637695eaf23741ba335925b195ecd", reg.Hash.String()) assert.Equal(t, "0c092117b4ba47b81001712425e6e7f760a637695eaf23741ba335925b195ecd", reg.Hash.ReverseString())
} }
func TestEncodeDecodeGenesisRegister(t *testing.T) { func TestEncodeDecodeGenesisRegister(t *testing.T) {
@ -50,5 +50,5 @@ func TestEncodeDecodeGenesisRegister(t *testing.T) {
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, rawtx, hex.EncodeToString(buf.Bytes())) assert.Equal(t, rawtx, hex.EncodeToString(buf.Bytes()))
assert.Equal(t, "c56f33fc6ecfcd0c225c4ab356fee59390af8560be0e930faebe74a6daff7c9b", reg.Hash.String()) assert.Equal(t, "c56f33fc6ecfcd0c225c4ab356fee59390af8560be0e930faebe74a6daff7c9b", reg.Hash.ReverseString())
} }

View file

@ -25,7 +25,7 @@ func TestEncodeDecodeState(t *testing.T) {
assert.Equal(t, 1, len(s.Inputs)) assert.Equal(t, 1, len(s.Inputs))
input := s.Inputs[0] input := s.Inputs[0]
assert.Equal(t, "a192cbabc6d613ecfcce43fd09e9197556ca5cf7d4bd1f6c65726ea9f08441cb", input.PrevHash.String()) assert.Equal(t, "a192cbabc6d613ecfcce43fd09e9197556ca5cf7d4bd1f6c65726ea9f08441cb", input.PrevHash.ReverseString())
assert.Equal(t, uint16(0), input.PrevIndex) assert.Equal(t, uint16(0), input.PrevIndex)
assert.Equal(t, 1, len(s.Descriptors)) assert.Equal(t, 1, len(s.Descriptors))
@ -43,5 +43,5 @@ func TestEncodeDecodeState(t *testing.T) {
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, rawtxBytes, buf.Bytes()) assert.Equal(t, rawtxBytes, buf.Bytes())
assert.Equal(t, "8abf5ebdb9a8223b12109513647f45bd3c0a6cf1a6346d56684cff71ba308724", s.Hash.String()) assert.Equal(t, "8abf5ebdb9a8223b12109513647f45bd3c0a6cf1a6346d56684cff71ba308724", s.Hash.ReverseString())
} }

View file

@ -30,3 +30,15 @@ const (
MainNet Magic = 7630401 MainNet Magic = 7630401
TestNet Magic = 0x74746e41 TestNet Magic = 0x74746e41
) )
// String implements the stringer interface
func (m Magic) String() string {
switch m {
case MainNet:
return "Mainnet"
case TestNet:
return "Testnet"
default:
return "UnknownNet"
}
}

View file

@ -1,19 +1,43 @@
package address package address
import ( import (
"encoding/hex"
"github.com/CityOfZion/neo-go/pkg/crypto/base58" "github.com/CityOfZion/neo-go/pkg/crypto/base58"
"github.com/CityOfZion/neo-go/pkg/wire/util"
) )
// ToScriptHash converts an address to a script hash // ToScriptHash converts an address to a script hash
func ToScriptHash(address string) string { func ToScriptHash(address string) string {
a, err := Uint160Decode(address)
decodedAddressAsBytes, err := base58.Decode(address)
if err != nil { if err != nil {
return "" return ""
} }
decodedAddressAsHex := hex.EncodeToString(decodedAddressAsBytes) return a.String()
scriptHash := (decodedAddressAsHex[2:42])
return scriptHash }
// ToReverseScriptHash converts an address to a reverse script hash
func ToReverseScriptHash(address string) string {
a, err := Uint160Decode(address)
if err != nil {
return ""
}
return a.ReverseString()
}
// FromUint160 returns the "NEO address" from the given
// Uint160.
func FromUint160(u util.Uint160) (string, error) {
// Dont forget to prepend the Address version 0x17 (23) A
b := append([]byte{0x17}, u.Bytes()...)
return base58.CheckEncode(b)
}
// Uint160Decode attempts to decode the given NEO address string
// into an Uint160.
func Uint160Decode(s string) (u util.Uint160, err error) {
b, err := base58.CheckDecode(s)
if err != nil {
return u, err
}
return util.Uint160DecodeBytes(b[1:21])
} }

View file

@ -0,0 +1,17 @@
package address
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestScriptHash(t *testing.T) {
address := "AJeAEsmeD6t279Dx4n2HWdUvUmmXQ4iJvP"
hash := ToScriptHash(address)
reverseHash := ToReverseScriptHash(address)
assert.Equal(t, "b28427088a3729b2536d10122960394e8be6721f", reverseHash)
assert.Equal(t, "1f72e68b4e39602912106d53b229378a082784b2", hash)
}

View file

@ -1,126 +0,0 @@
package base58
import (
"bytes"
"fmt"
"math/big"
"github.com/CityOfZion/neo-go/pkg/wire/util/crypto/hash"
)
const prefix rune = '1'
var decodeMap = map[rune]int64{
'1': 0, '2': 1, '3': 2, '4': 3, '5': 4,
'6': 5, '7': 6, '8': 7, '9': 8, 'A': 9,
'B': 10, 'C': 11, 'D': 12, 'E': 13, 'F': 14,
'G': 15, 'H': 16, 'J': 17, 'K': 18, 'L': 19,
'M': 20, 'N': 21, 'P': 22, 'Q': 23, 'R': 24,
'S': 25, 'T': 26, 'U': 27, 'V': 28, 'W': 29,
'X': 30, 'Y': 31, 'Z': 32, 'a': 33, 'b': 34,
'c': 35, 'd': 36, 'e': 37, 'f': 38, 'g': 39,
'h': 40, 'i': 41, 'j': 42, 'k': 43, 'm': 44,
'n': 45, 'o': 46, 'p': 47, 'q': 48, 'r': 49,
's': 50, 't': 51, 'u': 52, 'v': 53, 'w': 54,
'x': 55, 'y': 56, 'z': 57,
}
// Decode decodes the base58 encoded string.
func Decode(s string) ([]byte, error) {
var (
startIndex = 0
zero = 0
)
for i, c := range s {
if c == prefix {
zero++
} else {
startIndex = i
break
}
}
var (
n = big.NewInt(0)
div = big.NewInt(58)
)
for _, c := range s[startIndex:] {
charIndex, ok := decodeMap[c]
if !ok {
return nil, fmt.Errorf(
"invalid character '%c' when decoding this base58 string: '%s'", c, s,
)
}
n.Add(n.Mul(n, div), big.NewInt(charIndex))
}
out := n.Bytes()
buf := make([]byte, (zero + len(out)))
copy(buf[zero:], out[:])
return buf, nil
}
// Encode encodes a byte slice to be a base58 encoded string.
func Encode(bytes []byte) string {
var (
lookupTable = "123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz"
x = new(big.Int).SetBytes(bytes)
r = new(big.Int)
m = big.NewInt(58)
zero = big.NewInt(0)
encoded string
)
for x.Cmp(zero) > 0 {
x.QuoRem(x, m, r)
encoded = string(lookupTable[r.Int64()]) + encoded
}
return encoded
}
// CheckDecode decodes the given string.
func CheckDecode(s string) (b []byte, err error) {
b, err = Decode(s)
if err != nil {
return nil, err
}
for i := 0; i < len(s); i++ {
if s[i] != '1' {
break
}
b = append([]byte{0x00}, b...)
}
if len(b) < 5 {
return nil, fmt.Errorf("Invalid base-58 check string: missing checksum. -1")
}
hash, err := hash.DoubleSha256(b[:len(b)-4])
if err != nil {
return nil, fmt.Errorf("Could not double sha256 data")
}
if bytes.Compare(hash[0:4], b[len(b)-4:]) != 0 {
return nil, fmt.Errorf("Invalid base-58 check string: invalid checksum. -2")
}
// Strip the 4 byte long hash.
b = b[:len(b)-4]
return b, nil
}
// CheckEncode encodes b into a base-58 check encoded string.
func CheckEncode(b []byte) (string, error) {
hash, err := hash.DoubleSha256(b)
if err != nil {
return "", fmt.Errorf("Could not double sha256 data")
}
b = append(b, hash[0:4]...)
return Encode(b), nil
}

View file

@ -1,32 +0,0 @@
package base58
import (
"encoding/hex"
"testing"
"github.com/stretchr/testify/assert"
)
func TestDecode(t *testing.T) {
input := "1F1tAaz5x1HUXrCNLbtMDqcw6o5GNn4xqX"
data, err := Decode(input)
if err != nil {
t.Fatal(err)
}
expected := "0099bc78ba577a95a11f1a344d4d2ae55f2f857b989ea5e5e2"
actual := hex.EncodeToString(data)
assert.Equal(t, expected, actual)
}
func TestEncode(t *testing.T) {
input := "0099bc78ba577a95a11f1a344d4d2ae55f2f857b989ea5e5e2"
inputBytes, _ := hex.DecodeString(input)
data := Encode(inputBytes)
expected := "F1tAaz5x1HUXrCNLbtMDqcw6o5GNn4xqX" // Removed the 1 as it is not checkEncoding
actual := data
assert.Equal(t, expected, actual)
}

View file

@ -1,83 +0,0 @@
package hash
import (
"crypto/sha256"
"io"
"github.com/CityOfZion/neo-go/pkg/wire/util"
"golang.org/x/crypto/ripemd160"
)
// Sha256 hashes the byte slice using sha256
func Sha256(data []byte) (util.Uint256, error) {
var hash util.Uint256
hasher := sha256.New()
hasher.Reset()
_, err := hasher.Write(data)
hash, err = util.Uint256DecodeBytes(hasher.Sum(nil))
if err != nil {
return hash, err
}
return hash, nil
}
// DoubleSha256 hashes the underlying data twice using sha256
func DoubleSha256(data []byte) (util.Uint256, error) {
var hash util.Uint256
h1, err := Sha256(data)
if err != nil {
return hash, err
}
hash, err = Sha256(h1.Bytes())
if err != nil {
return hash, err
}
return hash, nil
}
// RipeMD160 hashes the underlying data using ripemd160
func RipeMD160(data []byte) (util.Uint160, error) {
var hash util.Uint160
hasher := ripemd160.New()
hasher.Reset()
_, err := io.WriteString(hasher, string(data))
hash, err = util.Uint160DecodeBytes(hasher.Sum(nil))
if err != nil {
return hash, err
}
return hash, nil
}
//Hash160 hashes the underlying data using sha256 then ripemd160
func Hash160(data []byte) (util.Uint160, error) {
var hash util.Uint160
h1, err := Sha256(data)
h2, err := RipeMD160(h1.Bytes())
hash, err = util.Uint160DecodeBytes(h2.Bytes())
if err != nil {
return hash, err
}
return hash, nil
}
// Checksum calculates the checksum of the byte slice using sha256
func Checksum(data []byte) ([]byte, error) {
hash, err := Sum(data)
if err != nil {
return nil, err
}
return hash[:4], nil
}
// Sum calculates the Sum of the data by using double sha256
func Sum(b []byte) (util.Uint256, error) {
hash, err := DoubleSha256((b))
return hash, err
}

View file

@ -1,62 +0,0 @@
package hash
import (
"encoding/hex"
"testing"
"github.com/stretchr/testify/assert"
)
func TestSha256(t *testing.T) {
input := []byte("hello")
data, err := Sha256(input)
if err != nil {
t.Fatal(err)
}
expected := "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824"
actual := hex.EncodeToString(data.Bytes()) // MARK: In the DecodeBytes function, there is a bytes reverse, not sure why?
assert.Equal(t, expected, actual)
}
func TestHashDoubleSha256(t *testing.T) {
input := []byte("hello")
data, err := DoubleSha256(input)
if err != nil {
t.Fatal(err)
}
firstSha, _ := Sha256(input)
doubleSha, _ := Sha256(firstSha.Bytes())
expected := hex.EncodeToString(doubleSha.Bytes())
actual := hex.EncodeToString(data.Bytes())
assert.Equal(t, expected, actual)
}
func TestHashRipeMD160(t *testing.T) {
input := []byte("hello")
data, err := RipeMD160(input)
if err != nil {
t.Fatal(err)
}
expected := "108f07b8382412612c048d07d13f814118445acd"
actual := hex.EncodeToString(data.Bytes())
assert.Equal(t, expected, actual)
}
func TestHash160(t *testing.T) {
input := "02cccafb41b220cab63fd77108d2d1ebcffa32be26da29a04dca4996afce5f75db"
publicKeyBytes, _ := hex.DecodeString(input)
data, err := Hash160(publicKeyBytes)
if err != nil {
t.Fatal(err)
}
expected := "c8e2b685cc70ec96743b55beb9449782f8f775d8"
actual := hex.EncodeToString(data.Bytes())
assert.Equal(t, expected, actual)
}

View file

@ -68,6 +68,11 @@ func (u Uint160) String() string {
return hex.EncodeToString(u.Bytes()) return hex.EncodeToString(u.Bytes())
} }
// ReverseString implements the stringer interface.
func (u Uint160) ReverseString() string {
return hex.EncodeToString(u.BytesReverse())
}
// Equals returns true if both Uint256 values are the same. // Equals returns true if both Uint256 values are the same.
func (u Uint160) Equals(other Uint160) bool { func (u Uint160) Equals(other Uint160) bool {
for i := 0; i < uint160Size; i++ { for i := 0; i < uint160Size; i++ {

View file

@ -48,3 +48,15 @@ func TestUInt160Equals(t *testing.T) {
t.Fatalf("%s and %s must be equal", ua, ua) t.Fatalf("%s and %s must be equal", ua, ua)
} }
} }
func TestUInt160String(t *testing.T) {
hexStr := "b28427088a3729b2536d10122960394e8be6721f"
hexRevStr := "1f72e68b4e39602912106d53b229378a082784b2"
val, err := Uint160DecodeString(hexStr)
assert.Nil(t, err)
assert.Equal(t, hexStr, val.String())
assert.Equal(t, hexRevStr, val.ReverseString())
}

View file

@ -63,6 +63,11 @@ func (u Uint256) Equals(other Uint256) bool {
// String implements the stringer interface. // String implements the stringer interface.
func (u Uint256) String() string { func (u Uint256) String() string {
return hex.EncodeToString(u.Bytes())
}
// ReverseString displays a reverse string representation of Uint256.
func (u Uint256) ReverseString() string {
return hex.EncodeToString(slice.Reverse(u.Bytes())) return hex.EncodeToString(slice.Reverse(u.Bytes()))
} }

View file

@ -13,7 +13,7 @@ func TestUint256DecodeString(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
assert.Equal(t, hexStr, val.Reverse().String()) assert.Equal(t, hexStr, val.String())
} }
func TestUint256DecodeBytes(t *testing.T) { func TestUint256DecodeBytes(t *testing.T) {
@ -26,7 +26,7 @@ func TestUint256DecodeBytes(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
assert.Equal(t, hexStr, val.Reverse().String()) assert.Equal(t, hexStr, val.String())
} }
func TestUInt256Equals(t *testing.T) { func TestUInt256Equals(t *testing.T) {