Refactor of the Go node (#44)

* added headersOp for safely processing headers

* Better handling of protocol messages.

* housekeeping + cleanup tests

* Added more blockchain logic + unit tests

* fixed unreachable error.

* added structured logging for all (node) components.

* added relay flag + bumped version
This commit is contained in:
Anthony De Meulemeester 2018-03-09 16:55:25 +01:00 committed by GitHub
parent b2a5e34aac
commit 4023661cf1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
43 changed files with 1497 additions and 1265 deletions

0
Normal file
View file

26
Gopkg.lock generated
View file

@ -19,12 +19,36 @@
revision = "346938d642f2ec3594ed81d874461961cd0faa76" revision = "346938d642f2ec3594ed81d874461961cd0faa76"
version = "v1.1.0" version = "v1.1.0"
[[projects]]
name = "github.com/go-kit/kit"
packages = ["log"]
revision = "4dc7be5d2d12881735283bcab7352178e190fc71"
version = "v0.6.0"
[[projects]]
name = "github.com/go-logfmt/logfmt"
packages = ["."]
revision = "390ab7935ee28ec6b286364bba9b4dd6410cb3d5"
version = "v0.3.0"
[[projects]]
name = "github.com/go-stack/stack"
packages = ["."]
revision = "259ab82a6cad3992b4e21ff5cac294ccb06474bc"
version = "v1.7.0"
[[projects]] [[projects]]
branch = "master" branch = "master"
name = "github.com/golang/snappy" name = "github.com/golang/snappy"
packages = ["."] packages = ["."]
revision = "553a641470496b2327abcac10b36396bd98e45c9" revision = "553a641470496b2327abcac10b36396bd98e45c9"
[[projects]]
branch = "master"
name = "github.com/kr/logfmt"
packages = ["."]
revision = "b84e30acd515aadc4b783ad4ff83aff3299bdfe0"
[[projects]] [[projects]]
name = "github.com/pmezard/go-difflib" name = "github.com/pmezard/go-difflib"
packages = ["difflib"] packages = ["difflib"]
@ -98,6 +122,6 @@
[solve-meta] [solve-meta]
analyzer-name = "dep" analyzer-name = "dep"
analyzer-version = 1 analyzer-version = 1
inputs-digest = "069a738aa1487766b26f9efb8103d2ce0526d43c83049cb5b792f0edf91568de" inputs-digest = "53597073e919ad7bf52895a19f8b8526d12d666862fb1d36b4a9756e0499da5a"
solver-name = "gps-cdcl" solver-name = "gps-cdcl"
solver-version = 1 solver-version = 1

View file

@ -51,3 +51,7 @@
[[constraint]] [[constraint]]
name = "github.com/stretchr/testify" name = "github.com/stretchr/testify"
version = "1.2.1" version = "1.2.1"
[[constraint]]
name = "github.com/go-kit/kit"
version = "0.6.0"

View file

@ -19,7 +19,7 @@ push-tag:
git push origin ${BRANCH} --tags git push origin ${BRANCH} --tags
run: build run: build
./bin/neo-go node -seed ${SEEDS} -tcp ${PORT} ./bin/neo-go node -seed ${SEEDS} -tcp ${PORT} --relay true
test: test:
@go test ./... -cover @go test ./... -cover

View file

@ -1 +1 @@
0.25.0 0.26.0

View file

@ -16,6 +16,7 @@ func NewCommand() cli.Command {
Flags: []cli.Flag{ Flags: []cli.Flag{
cli.IntFlag{Name: "tcp"}, cli.IntFlag{Name: "tcp"},
cli.IntFlag{Name: "rpc"}, cli.IntFlag{Name: "rpc"},
cli.BoolFlag{Name: "relay, r"},
cli.StringFlag{Name: "seed"}, cli.StringFlag{Name: "seed"},
cli.BoolFlag{Name: "privnet, p"}, cli.BoolFlag{Name: "privnet, p"},
cli.BoolFlag{Name: "mainnet, m"}, cli.BoolFlag{Name: "mainnet, m"},
@ -25,12 +26,6 @@ func NewCommand() cli.Command {
} }
func startServer(ctx *cli.Context) error { func startServer(ctx *cli.Context) error {
opts := network.StartOpts{
Seeds: parseSeeds(ctx.String("seed")),
TCP: ctx.Int("tcp"),
RPC: ctx.Int("rpc"),
}
net := network.ModePrivNet net := network.ModePrivNet
if ctx.Bool("testnet") { if ctx.Bool("testnet") {
net = network.ModeTestNet net = network.ModeTestNet
@ -39,8 +34,16 @@ func startServer(ctx *cli.Context) error {
net = network.ModeMainNet net = network.ModeMainNet
} }
s := network.NewServer(net) cfg := network.Config{
s.Start(opts) UserAgent: "/NEO-GO:0.26.0/",
ListenTCP: uint16(ctx.Int("tcp")),
Seeds: parseSeeds(ctx.String("seed")),
Net: net,
Relay: ctx.Bool("relay"),
}
s := network.NewServer(cfg)
s.Start()
return nil return nil
} }

15
main.go Normal file
View file

@ -0,0 +1,15 @@
package main
import (
"os"
log "github.com/go-kit/kit/log"
)
func main() {
logger := log.NewLogfmtLogger(log.NewSyncWriter(os.Stderr))
logger.Log("hello", true)
logger = log.With(logger, "module", "node")
logger.Log("foo", true)
}

View file

@ -33,6 +33,9 @@ type BlockBase struct {
_ uint8 // padding _ uint8 // padding
// Script used to validate the block // Script used to validate the block
Script *transaction.Witness Script *transaction.Witness
// hash of this block, created when binary encoded.
hash util.Uint256
} }
// DecodeBinary implements the payload interface. // DecodeBinary implements the payload interface.
@ -68,19 +71,35 @@ func (b *BlockBase) DecodeBinary(r io.Reader) error {
} }
b.Script = &transaction.Witness{} b.Script = &transaction.Witness{}
return b.Script.DecodeBinary(r) if err := b.Script.DecodeBinary(r); err != nil {
return err
}
// Make the hash of the block here so we dont need to do this
// again.
hash, err := b.createHash()
if err != nil {
return err
}
b.hash = hash
return nil
} }
// Hash returns the hash of the block. // Hash return the hash of the block.
func (b *BlockBase) Hash() util.Uint256 {
return b.hash
}
// createHash creates the hash of the block.
// When calculating the hash value of the block, instead of calculating the entire block, // When calculating the hash value of the block, instead of calculating the entire block,
// only first seven fields in the block head will be calculated, which are // only first seven fields in the block head will be calculated, which are
// version, PrevBlock, MerkleRoot, timestamp, and height, the nonce, NextMiner. // version, PrevBlock, MerkleRoot, timestamp, and height, the nonce, NextMiner.
// Since MerkleRoot already contains the hash value of all transactions, // Since MerkleRoot already contains the hash value of all transactions,
// the modification of transaction will influence the hash value of the block. // the modification of transaction will influence the hash value of the block.
func (b *BlockBase) Hash() (hash util.Uint256, err error) { func (b *BlockBase) createHash() (hash util.Uint256, err error) {
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
if err = b.encodeHashableFields(buf); err != nil { if err = b.encodeHashableFields(buf); err != nil {
return return hash, err
} }
// Double hash the encoded fields. // Double hash the encoded fields.
@ -92,15 +111,25 @@ func (b *BlockBase) Hash() (hash util.Uint256, err error) {
// encodeHashableFields will only encode the fields used for hashing. // encodeHashableFields will only encode the fields used for hashing.
// see Hash() for more information about the fields. // see Hash() for more information about the fields.
func (b *BlockBase) encodeHashableFields(w io.Writer) error { func (b *BlockBase) encodeHashableFields(w io.Writer) error {
err := binary.Write(w, binary.LittleEndian, &b.Version) if err := binary.Write(w, binary.LittleEndian, &b.Version); err != nil {
err = binary.Write(w, binary.LittleEndian, &b.PrevHash) return err
err = binary.Write(w, binary.LittleEndian, &b.MerkleRoot) }
err = binary.Write(w, binary.LittleEndian, &b.Timestamp) if err := binary.Write(w, binary.LittleEndian, &b.PrevHash); err != nil {
err = binary.Write(w, binary.LittleEndian, &b.Index) return err
err = binary.Write(w, binary.LittleEndian, &b.ConsensusData) }
err = binary.Write(w, binary.LittleEndian, &b.NextConsensus) if err := binary.Write(w, binary.LittleEndian, &b.MerkleRoot); err != nil {
return err
return err }
if err := binary.Write(w, binary.LittleEndian, &b.Timestamp); err != nil {
return err
}
if err := binary.Write(w, binary.LittleEndian, &b.Index); err != nil {
return err
}
if err := binary.Write(w, binary.LittleEndian, &b.ConsensusData); err != nil {
return err
}
return binary.Write(w, binary.LittleEndian, &b.NextConsensus)
} }
// EncodeBinary implements the Payload interface // EncodeBinary implements the Payload interface
@ -108,21 +137,16 @@ func (b *BlockBase) EncodeBinary(w io.Writer) error {
if err := b.encodeHashableFields(w); err != nil { if err := b.encodeHashableFields(w); err != nil {
return err return err
} }
// padding
if err := binary.Write(w, binary.LittleEndian, uint8(1)); err != nil { if err := binary.Write(w, binary.LittleEndian, uint8(1)); err != nil {
return err return err
} }
// script
return b.Script.EncodeBinary(w) return b.Script.EncodeBinary(w)
} }
// Header holds the head info of a block // Header holds the head info of a block
type Header struct { type Header struct {
BlockBase BlockBase
// fixed to 0 _ uint8 // padding fixed to 0
_ uint8 // padding
} }
// Verify the integrity of the header // Verify the integrity of the header
@ -150,15 +174,12 @@ func (h *Header) EncodeBinary(w io.Writer) error {
if err := h.BlockBase.EncodeBinary(w); err != nil { if err := h.BlockBase.EncodeBinary(w); err != nil {
return err return err
} }
// padding
return binary.Write(w, binary.LittleEndian, uint8(0)) return binary.Write(w, binary.LittleEndian, uint8(0))
} }
// Block represents one block in the chain. // Block represents one block in the chain.
type Block struct { type Block struct {
BlockBase BlockBase
// transaction list
Transactions []*transaction.Transaction Transactions []*transaction.Transaction
} }
@ -205,11 +226,10 @@ func (b *Block) DecodeBinary(r io.Reader) error {
lentx := util.ReadVarUint(r) lentx := util.ReadVarUint(r)
b.Transactions = make([]*transaction.Transaction, lentx) b.Transactions = make([]*transaction.Transaction, lentx)
for i := 0; i < int(lentx); i++ { for i := 0; i < int(lentx); i++ {
tx := &transaction.Transaction{} b.Transactions[i] = &transaction.Transaction{}
if err := tx.DecodeBinary(r); err != nil { if err := b.Transactions[i].DecodeBinary(r); err != nil {
return err return err
} }
b.Transactions[i] = tx
} }
return nil return nil

View file

@ -8,6 +8,7 @@ import (
"github.com/CityOfZion/neo-go/pkg/core/transaction" "github.com/CityOfZion/neo-go/pkg/core/transaction"
"github.com/CityOfZion/neo-go/pkg/util" "github.com/CityOfZion/neo-go/pkg/util"
"github.com/stretchr/testify/assert"
) )
func TestDecodeBlock(t *testing.T) { func TestDecodeBlock(t *testing.T) {
@ -29,63 +30,24 @@ func TestDecodeBlock(t *testing.T) {
if err := block.DecodeBinary(bytes.NewReader(rawBlockBytes)); err != nil { if err := block.DecodeBinary(bytes.NewReader(rawBlockBytes)); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if block.Index != uint32(rawBlockIndex) { assert.Equal(t, uint32(rawBlockIndex), block.Index)
t.Fatalf("expected the index to the block to be %d got %d", rawBlockIndex, block.Index) assert.Equal(t, uint32(rawBlockTimestamp), block.Timestamp)
} assert.Equal(t, uint64(rawBlockConsensusData), block.ConsensusData)
if block.Timestamp != uint32(rawBlockTimestamp) { assert.Equal(t, rawBlockPrevHash, block.PrevHash.String())
t.Fatalf("expected timestamp to be %d got %d", rawBlockTimestamp, block.Timestamp) assert.Equal(t, rawBlockHash, block.Hash().String())
}
if block.ConsensusData != uint64(rawBlockConsensusData) {
t.Fatalf("expected consensus data to be %d got %d", rawBlockConsensusData, block.ConsensusData)
}
if block.PrevHash.String() != rawBlockPrevHash {
t.Fatalf("expected prev block hash to be %s got %s", rawBlockPrevHash, block.PrevHash)
}
hash, err := block.Hash()
if err != nil {
t.Fatal(err)
}
if hash.String() != rawBlockHash {
t.Fatalf("expected hash of the block to be %s got %s", rawBlockHash, hash)
}
}
func newBlockBase() BlockBase {
return BlockBase{
Version: 0,
PrevHash: sha256.Sum256([]byte("a")),
MerkleRoot: sha256.Sum256([]byte("b")),
Timestamp: 999,
Index: 1,
ConsensusData: 1111,
NextConsensus: util.Uint160{},
Script: &transaction.Witness{
VerificationScript: []byte{0x0},
InvocationScript: []byte{0x1},
},
}
} }
func TestHashBlockEqualsHashHeader(t *testing.T) { func TestHashBlockEqualsHashHeader(t *testing.T) {
base := newBlockBase() block := newBlock(0)
b := &Block{BlockBase: base} assert.Equal(t, block.Hash(), block.Header().Hash())
head := &Header{BlockBase: base}
bhash, _ := b.Hash()
headhash, _ := head.Hash()
if bhash != headhash {
t.Fatalf("expected both hashes to be equal %s and %s", bhash, headhash)
}
} }
func TestBlockVerify(t *testing.T) { func TestBlockVerify(t *testing.T) {
block := &Block{ block := newBlock(
BlockBase: newBlockBase(), 0,
Transactions: []*transaction.Transaction{ newTX(transaction.MinerType),
{Type: transaction.MinerType}, newTX(transaction.IssueType),
{Type: transaction.IssueType}, )
},
}
if !block.Verify(false) { if !block.Verify(false) {
t.Fatal("block should be verified") t.Fatal("block should be verified")
@ -109,3 +71,19 @@ func TestBlockVerify(t *testing.T) {
t.Fatal("block should not by verified") t.Fatal("block should not by verified")
} }
} }
func newBlockBase() BlockBase {
return BlockBase{
Version: 0,
PrevHash: sha256.Sum256([]byte("a")),
MerkleRoot: sha256.Sum256([]byte("b")),
Timestamp: 999,
Index: 1,
ConsensusData: 1111,
NextConsensus: util.Uint160{},
Script: &transaction.Witness{
VerificationScript: []byte{0x0},
InvocationScript: []byte{0x1},
},
}
}

View file

@ -3,17 +3,19 @@ package core
import ( import (
"bytes" "bytes"
"encoding/binary" "encoding/binary"
"log" "fmt"
"sync" "os"
"sync/atomic"
"time" "time"
"github.com/CityOfZion/neo-go/pkg/util" "github.com/CityOfZion/neo-go/pkg/util"
log "github.com/go-kit/kit/log"
) )
// tuning parameters // tuning parameters
const ( const (
secondsPerBlock = 15 secondsPerBlock = 15
writeHdrBatchCnt = 2000 headerBatchCount = 2000
) )
var ( var (
@ -22,188 +24,217 @@ var (
// Blockchain holds the chain. // Blockchain holds the chain.
type Blockchain struct { type Blockchain struct {
logger *log.Logger logger log.Logger
// Any object that satisfies the BlockchainStorer interface. // Any object that satisfies the BlockchainStorer interface.
Store Store
// current index of the heighest block // Current index/height of the highest block.
currentBlockHeight uint32 // Read access should always be called by BlockHeight().
// Writes access should only happen in persist().
blockHeight uint32
// number of headers stored // Number of headers stored.
storedHeaderCount uint32 storedHeaderCount uint32
mtx sync.RWMutex blockCache *Cache
// index of headers hashes startHash util.Uint256
headerIndex []util.Uint256
// Only for operating on the headerList.
headersOp chan headersOpFunc
headersOpDone chan struct{}
} }
// NewBlockchain returns a pointer to a Blockchain. type headersOpFunc func(headerList *HeaderHashList)
func NewBlockchain(s Store, l *log.Logger, startHash util.Uint256) *Blockchain {
// NewBlockchain creates a new Blockchain object.
func NewBlockchain(s Store, startHash util.Uint256) *Blockchain {
logger := log.NewLogfmtLogger(os.Stderr)
logger = log.With(logger, "component", "blockchain")
bc := &Blockchain{ bc := &Blockchain{
logger: l, logger: logger,
Store: s, Store: s,
headersOp: make(chan headersOpFunc),
headersOpDone: make(chan struct{}),
startHash: startHash,
blockCache: NewCache(),
} }
go bc.run()
// Starthash is 0, so we will create the genesis block. bc.init()
if startHash.Equals(util.Uint256{}) {
bc.logger.Fatal("genesis block not yet implemented")
}
bc.headerIndex = []util.Uint256{startHash}
return bc return bc
} }
// genesisBlock creates the genesis block for the chain. func (bc *Blockchain) init() {
// hash of the genesis block: // for the initial header, for now
// d42561e3d30e15be6400b6df2f328e02d2bf6354c41dce433bc57687c82144bf bc.storedHeaderCount = 1
func (bc *Blockchain) genesisBlock() *Block { }
timestamp := uint32(time.Date(2016, 7, 15, 15, 8, 21, 0, time.UTC).Unix())
// TODO: for testing I will hardcode the merkleroot. func (bc *Blockchain) run() {
// This let's me focus on the bringing all the puzzle pieces headerList := NewHeaderHashList(bc.startHash)
// togheter much faster. for {
// For more information about the genesis block: select {
// https://neotracker.io/block/height/0 case op := <-bc.headersOp:
mr, _ := util.Uint256DecodeString("803ff4abe3ea6533bcc0be574efa02f83ae8fdc651c879056b0d9be336c01bf4") op(headerList)
bc.headersOpDone <- struct{}{}
return &Block{ }
BlockBase: BlockBase{
Version: 0,
PrevHash: util.Uint256{},
MerkleRoot: mr,
Timestamp: timestamp,
Index: 0,
ConsensusData: 2083236893, // nioctib ^^
NextConsensus: util.Uint160{}, // todo
},
} }
} }
// AddBlock (to be continued after headers is finished..)
func (bc *Blockchain) AddBlock(block *Block) error { func (bc *Blockchain) AddBlock(block *Block) error {
// TODO: caching if !bc.blockCache.Has(block.Hash()) {
headerLen := len(bc.headerIndex) bc.blockCache.Add(block.Hash(), block)
}
headerLen := int(bc.HeaderHeight() + 1)
if int(block.Index-1) >= headerLen { if int(block.Index-1) >= headerLen {
return nil return nil
} }
if int(block.Index) == headerLen { if int(block.Index) == headerLen {
// todo: if (VerifyBlocks && !block.Verify()) return false; // todo: if (VerifyBlocks && !block.Verify()) return false;
} }
return bc.AddHeaders(block.Header())
if int(block.Index) < headerLen {
return nil
}
return nil
} }
func (bc *Blockchain) addHeader(header *Header) error { func (bc *Blockchain) AddHeaders(headers ...*Header) (err error) {
return bc.AddHeaders(header) var (
start = time.Now()
batch = Batch{}
)
bc.headersOp <- func(headerList *HeaderHashList) {
for _, h := range headers {
if int(h.Index-1) >= headerList.Len() {
err = fmt.Errorf(
"height of block higher then current header height %d > %d\n",
h.Index, headerList.Len(),
)
return
}
if int(h.Index) < headerList.Len() {
continue
}
if !h.Verify() {
err = fmt.Errorf("header %v is invalid", h)
return
}
if err = bc.processHeader(h, batch, headerList); err != nil {
return
}
}
// TODO: Implement caching strategy.
if len(batch) > 0 {
if err = bc.writeBatch(batch); err != nil {
return
}
bc.logger.Log(
"msg", "done processing headers",
"index", headerList.Len()-1,
"took", time.Since(start).Seconds(),
)
}
}
<-bc.headersOpDone
return err
} }
// AddHeaders processes the given headers. // processHeader processes the given header. Note that this is only thread safe
func (bc *Blockchain) AddHeaders(headers ...*Header) error { // if executed in headers operation.
start := time.Now() func (bc *Blockchain) processHeader(h *Header, batch Batch, headerList *HeaderHashList) error {
headerList.Add(h.Hash())
bc.mtx.Lock()
defer bc.mtx.Unlock()
batch := Batch{}
for _, h := range headers {
if int(h.Index-1) >= len(bc.headerIndex) {
bc.logger.Printf("height of block higher then header index %d %d\n",
h.Index, len(bc.headerIndex))
break
}
if int(h.Index) < len(bc.headerIndex) {
continue
}
if !h.Verify() {
bc.logger.Printf("header %v is invalid", h)
break
}
if err := bc.processHeader(h, batch); err != nil {
return err
}
}
// TODO: Implement caching strategy.
if len(batch) > 0 {
// Write all batches.
if err := bc.writeBatch(batch); err != nil {
return err
}
bc.logger.Printf("done processing headers up to index %d took %f Seconds",
bc.HeaderHeight(), time.Since(start).Seconds())
}
return nil
}
// processHeader processes 1 header.
func (bc *Blockchain) processHeader(h *Header, batch Batch) error {
hash, err := h.Hash()
if err != nil {
return err
}
bc.headerIndex = append(bc.headerIndex, hash)
for int(h.Index)-writeHdrBatchCnt >= int(bc.storedHeaderCount) {
// hdrsToWrite = bc.headerIndex[bc.storedHeaderCount : bc.storedHeaderCount+writeHdrBatchCnt]
// NOTE: from original #c to be implemented:
//
// w.Write(header_index.Skip((int)stored_header_count).Take(2000).ToArray());
// w.Flush();
// batch.Put(SliceBuilder.Begin(DataEntryPrefix.IX_HeaderHashList).Add(stored_header_count), ms.ToArray());
bc.storedHeaderCount += writeHdrBatchCnt
}
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
for int(h.Index)-headerBatchCount >= int(bc.storedHeaderCount) {
if err := headerList.Write(buf, int(bc.storedHeaderCount), headerBatchCount); err != nil {
return err
}
key := makeEntryPrefixInt(preIXHeaderHashList, int(bc.storedHeaderCount))
batch[&key] = buf.Bytes()
bc.storedHeaderCount += headerBatchCount
buf.Reset()
}
buf.Reset()
if err := h.EncodeBinary(buf); err != nil { if err := h.EncodeBinary(buf); err != nil {
return err return err
} }
preBlock := preDataBlock.add(hash.BytesReverse()) key := makeEntryPrefix(preDataBlock, h.Hash().BytesReverse())
batch[&preBlock] = buf.Bytes() batch[&key] = buf.Bytes()
preHeader := preSYSCurrentHeader.toSlice() key = preSYSCurrentHeader.bytes()
batch[&preHeader] = hashAndIndexToBytes(hash, h.Index) batch[&key] = hashAndIndexToBytes(h.Hash(), h.Index)
return nil return nil
} }
// CurrentBlockHash return the lastest hash in the header index. func (bc *Blockchain) persistBlock(block *Block) error {
func (bc *Blockchain) CurrentBlockHash() (hash util.Uint256) { bc.blockHeight = block.Index
if len(bc.headerIndex) == 0 { return nil
return }
}
if len(bc.headerIndex) < int(bc.currentBlockHeight) {
return
}
return bc.headerIndex[bc.currentBlockHeight] func (bc *Blockchain) persist() (err error) {
var (
persisted = 0
lenCache = bc.blockCache.Len()
)
for lenCache > persisted {
if bc.HeaderHeight()+1 <= bc.BlockHeight() {
break
}
bc.headersOp <- func(headerList *HeaderHashList) {
hash := headerList.Get(int(bc.BlockHeight() + 1))
if block, ok := bc.blockCache.GetBlock(hash); ok {
if err = bc.persistBlock(block); err != nil {
return
}
bc.blockCache.Delete(hash)
persisted++
} else {
bc.logger.Log(
"msg", "block not found in cache",
"hash", block.Hash(),
)
}
}
<-bc.headersOpDone
}
return
}
// CurrentBlockHash returns the heighest processed block hash.
func (bc *Blockchain) CurrentBlockHash() (hash util.Uint256) {
bc.headersOp <- func(headerList *HeaderHashList) {
hash = headerList.Get(int(bc.BlockHeight()))
}
<-bc.headersOpDone
return
} }
// CurrentHeaderHash returns the hash of the latest known header. // CurrentHeaderHash returns the hash of the latest known header.
func (bc *Blockchain) CurrentHeaderHash() (hash util.Uint256) { func (bc *Blockchain) CurrentHeaderHash() (hash util.Uint256) {
return bc.headerIndex[len(bc.headerIndex)-1] bc.headersOp <- func(headerList *HeaderHashList) {
hash = headerList.Last()
}
<-bc.headersOpDone
return
} }
// BlockHeight return the height/index of the latest block this node has. // BlockHeight returns the height/index of the highest block.
func (bc *Blockchain) BlockHeight() uint32 { func (bc *Blockchain) BlockHeight() uint32 {
return bc.currentBlockHeight return atomic.LoadUint32(&bc.blockHeight)
} }
// HeaderHeight returns the current index of the headers. // HeaderHeight returns the index/height of the highest header.
func (bc *Blockchain) HeaderHeight() uint32 { func (bc *Blockchain) HeaderHeight() (n uint32) {
return uint32(len(bc.headerIndex)) - 1 bc.headersOp <- func(headerList *HeaderHashList) {
n = uint32(headerList.Len() - 1)
}
<-bc.headersOpDone
return
} }
func hashAndIndexToBytes(h util.Uint256, index uint32) []byte { func hashAndIndexToBytes(h util.Uint256, index uint32) []byte {

View file

@ -1,48 +1,70 @@
package core package core
import ( import (
"log"
"os"
"testing" "testing"
"github.com/CityOfZion/neo-go/pkg/core/transaction"
"github.com/CityOfZion/neo-go/pkg/util" "github.com/CityOfZion/neo-go/pkg/util"
"github.com/stretchr/testify/assert"
) )
func TestNewBlockchain(t *testing.T) { func TestNewBlockchain(t *testing.T) {
startHash, _ := util.Uint256DecodeString("996e37358dc369912041f966f8c5d8d3a8255ba5dcbd3447f8a82b55db869099") startHash, _ := util.Uint256DecodeString("996e37358dc369912041f966f8c5d8d3a8255ba5dcbd3447f8a82b55db869099")
bc := NewBlockchain(nil, nil, startHash) bc := NewBlockchain(nil, startHash)
want := uint32(0) assert.Equal(t, uint32(0), bc.BlockHeight())
if have := bc.BlockHeight(); want != have { assert.Equal(t, uint32(0), bc.HeaderHeight())
t.Fatalf("expected %d got %d", want, have) assert.Equal(t, uint32(1), bc.storedHeaderCount)
} assert.Equal(t, startHash, bc.startHash)
if have := bc.HeaderHeight(); want != have {
t.Fatalf("expected %d got %d", want, have)
}
if have := bc.storedHeaderCount; want != have {
t.Fatalf("expected %d got %d", want, have)
}
if !bc.CurrentBlockHash().Equals(startHash) {
t.Fatalf("expected current block hash to be %d got %s", startHash, bc.CurrentBlockHash())
}
} }
func TestAddHeaders(t *testing.T) { func TestAddHeaders(t *testing.T) {
startHash, _ := util.Uint256DecodeString("996e37358dc369912041f966f8c5d8d3a8255ba5dcbd3447f8a82b55db869099") bc := newTestBC()
bc := NewBlockchain(NewMemoryStore(), log.New(os.Stdout, "", 0), startHash) h1 := newBlock(1).Header()
h2 := newBlock(2).Header()
h1 := &Header{BlockBase: BlockBase{Version: 0, Index: 1, Script: &transaction.Witness{}}} h3 := newBlock(3).Header()
h2 := &Header{BlockBase: BlockBase{Version: 0, Index: 2, Script: &transaction.Witness{}}}
h3 := &Header{BlockBase: BlockBase{Version: 0, Index: 3, Script: &transaction.Witness{}}}
if err := bc.AddHeaders(h1, h2, h3); err != nil { if err := bc.AddHeaders(h1, h2, h3); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if want, have := h3.Index, bc.HeaderHeight(); want != have {
t.Fatalf("expected header height of %d got %d", want, have) assert.Equal(t, 0, bc.blockCache.Len())
} assert.Equal(t, h3.Index, bc.HeaderHeight())
if want, have := uint32(0), bc.storedHeaderCount; want != have { assert.Equal(t, uint32(1), bc.storedHeaderCount)
t.Fatalf("expected stored header count to be %d got %d", want, have) assert.Equal(t, uint32(0), bc.BlockHeight())
} assert.Equal(t, h3.Hash(), bc.CurrentHeaderHash())
}
func TestAddBlock(t *testing.T) {
bc := newTestBC()
blocks := []*Block{
newBlock(1),
newBlock(2),
newBlock(3),
}
for i := 0; i < len(blocks); i++ {
if err := bc.AddBlock(blocks[i]); err != nil {
t.Fatal(err)
}
}
lastBlock := blocks[len(blocks)-1]
assert.Equal(t, 3, bc.blockCache.Len())
assert.Equal(t, lastBlock.Index, bc.HeaderHeight())
assert.Equal(t, lastBlock.Hash(), bc.CurrentHeaderHash())
assert.Equal(t, uint32(1), bc.storedHeaderCount)
if err := bc.persist(); err != nil {
t.Fatal(err)
}
assert.Equal(t, lastBlock.Index, bc.BlockHeight())
assert.Equal(t, lastBlock.Hash(), bc.CurrentHeaderHash())
assert.Equal(t, 0, bc.blockCache.Len())
}
func newTestBC() *Blockchain {
startHash, _ := util.Uint256DecodeString("a")
bc := NewBlockchain(NewMemoryStore(), startHash)
return bc
} }

73
pkg/core/cache.go Normal file
View file

@ -0,0 +1,73 @@
package core
import (
"sync"
"github.com/CityOfZion/neo-go/pkg/util"
)
// Cache is data structure with fixed type key of Uint256, but has a
// generic value. Used for block and header cash types.
type Cache struct {
lock sync.RWMutex
m map[util.Uint256]interface{}
}
// NewCache returns a ready to use Cache object.
func NewCache() *Cache {
return &Cache{
m: make(map[util.Uint256]interface{}),
}
}
// GetBlock will return a Block type from the cache.
func (c *Cache) GetBlock(h util.Uint256) (block *Block, ok bool) {
c.lock.RLock()
defer c.lock.RUnlock()
return c.getBlock(h)
}
func (c *Cache) getBlock(h util.Uint256) (block *Block, ok bool) {
if v, b := c.m[h]; b {
block, ok = v.(*Block)
return
}
return
}
// Add adds the given hash along with its value to the cache.
func (c *Cache) Add(h util.Uint256, v interface{}) {
c.lock.Lock()
defer c.lock.Unlock()
c.add(h, v)
}
func (c *Cache) add(h util.Uint256, v interface{}) {
c.m[h] = v
}
func (c *Cache) has(h util.Uint256) bool {
_, ok := c.m[h]
return ok
}
// Hash returns whether the cach contains the given hash.
func (c *Cache) Has(h util.Uint256) bool {
c.lock.Lock()
defer c.lock.Unlock()
return c.has(h)
}
// Len return the number of items present in the cache.
func (c *Cache) Len() int {
c.lock.RLock()
defer c.lock.RUnlock()
return len(c.m)
}
// Delete removes the item out of the cache.
func (c *Cache) Delete(h util.Uint256) {
c.lock.Lock()
defer c.lock.Unlock()
delete(c.m, h)
}

View file

@ -0,0 +1,66 @@
package core
import (
"encoding/binary"
"io"
"github.com/CityOfZion/neo-go/pkg/util"
)
// A HeaderHashList represents a list of header hashes.
type HeaderHashList struct {
hashes []util.Uint256
}
// NewHeaderHashList return a new pointer to a HeaderHashList.
func NewHeaderHashList(hashes ...util.Uint256) *HeaderHashList {
return &HeaderHashList{
hashes: hashes,
}
}
// Add appends the given hash to the list of hashes.
func (l *HeaderHashList) Add(h util.Uint256) {
l.hashes = append(l.hashes, h)
}
// Len return the length of the underlying hashes slice.
func (l *HeaderHashList) Len() int {
return len(l.hashes)
}
// Get returns the hash by the given index.
func (l *HeaderHashList) Get(i int) util.Uint256 {
if l.Len() < i {
return util.Uint256{}
}
return l.hashes[i]
}
// Last return the last hash in the HeaderHashList.
func (l *HeaderHashList) Last() util.Uint256 {
return l.hashes[l.Len()-1]
}
// Slice return a subslice of the underlying hashes.
// Subsliced from start to end.
// Example:
// headers := headerList.Slice(0, 2000)
func (l *HeaderHashList) Slice(start, end int) []util.Uint256 {
return l.hashes[start:end]
}
// WriteTo will write n underlying hashes to the given io.Writer
// starting from start.
func (l *HeaderHashList) Write(w io.Writer, start, n int) error {
if err := util.WriteVarUint(w, uint64(n)); err != nil {
return err
}
hashes := l.Slice(start, start+n)
for _, hash := range hashes {
if err := binary.Write(w, binary.LittleEndian, hash); err != nil {
return err
}
}
return nil
}

40
pkg/core/helper_test.go Normal file
View file

@ -0,0 +1,40 @@
package core
import (
"crypto/sha256"
"time"
"github.com/CityOfZion/neo-go/pkg/core/transaction"
"github.com/CityOfZion/neo-go/pkg/util"
)
func newBlock(index uint32, txs ...*transaction.Transaction) *Block {
b := &Block{
BlockBase: BlockBase{
Version: 0,
PrevHash: sha256.Sum256([]byte("a")),
MerkleRoot: sha256.Sum256([]byte("b")),
Timestamp: uint32(time.Now().UTC().Unix()),
Index: index,
ConsensusData: 1111,
NextConsensus: util.Uint160{},
Script: &transaction.Witness{
VerificationScript: []byte{0x0},
InvocationScript: []byte{0x1},
},
},
Transactions: txs,
}
hash, err := b.createHash()
if err != nil {
panic(err)
}
b.hash = hash
return b
}
func newTX(t transaction.TXType) *transaction.Transaction {
return &transaction.Transaction{
Type: t,
}
}

View file

@ -21,6 +21,5 @@ func (s *LevelDBStore) writeBatch(batch Batch) error {
for k, v := range batch { for k, v := range batch {
b.Put(*k, v) b.Put(*k, v)
} }
return s.db.Write(b, nil) return s.db.Write(b, nil)
} }

View file

@ -1,6 +1,7 @@
package core package core
// MemoryStore is an in memory implementation of a BlockChainStorer. // MemoryStore is an in memory implementation of a BlockChainStorer
// that should only be used for testing.
type MemoryStore struct { type MemoryStore struct {
} }
@ -14,5 +15,10 @@ func (m *MemoryStore) write(key, value []byte) error {
} }
func (m *MemoryStore) writeBatch(batch Batch) error { func (m *MemoryStore) writeBatch(batch Batch) error {
for k, v := range batch {
if err := m.write(*k, v); err != nil {
return err
}
}
return nil return nil
} }

View file

@ -1,17 +1,13 @@
package core package core
import (
"bytes"
"encoding/binary"
)
type dataEntry uint8 type dataEntry uint8
func (e dataEntry) add(b []byte) []byte { func (e dataEntry) bytes() []byte {
dest := make([]byte, len(b)+1)
dest[0] = byte(e)
for i := 1; i < len(b); i++ {
dest[i] = b[i]
}
return dest
}
func (e dataEntry) toSlice() []byte {
return []byte{byte(e)} return []byte{byte(e)}
} }
@ -32,6 +28,21 @@ const (
preSYSVersion dataEntry = 0xf0 preSYSVersion dataEntry = 0xf0
) )
func makeEntryPrefixInt(e dataEntry, n int) []byte {
buf := new(bytes.Buffer)
binary.Write(buf, binary.LittleEndian, n)
return makeEntryPrefix(e, buf.Bytes())
}
func makeEntryPrefix(e dataEntry, b []byte) []byte {
dest := make([]byte, len(b)+1)
dest[0] = byte(e)
for i := 1; i < len(b); i++ {
dest[i] = b[i]
}
return dest
}
// Store is anything that can persist and retrieve the blockchain. // Store is anything that can persist and retrieve the blockchain.
type Store interface { type Store interface {
write(k, v []byte) error write(k, v []byte) error
@ -39,5 +50,5 @@ type Store interface {
} }
// Batch is a data type used to store data for later batch operations // Batch is a data type used to store data for later batch operations
// by any Store. // that can be used by any Store interface implementation.
type Batch map[*[]byte][]byte type Batch map[*[]byte][]byte

View file

@ -10,7 +10,7 @@ import (
// Transaction is a process recorded in the NEO blockchain. // Transaction is a process recorded in the NEO blockchain.
type Transaction struct { type Transaction struct {
// The type of the transaction. // The type of the transaction.
Type TransactionType Type TXType
// The trading version which is currently 0. // The trading version which is currently 0.
Version uint8 Version uint8

View file

@ -1,26 +1,26 @@
package transaction package transaction
// TransactionType is the type of a transaction. // TXType is the type of a transaction.
type TransactionType uint8 type TXType uint8
// All processes in NEO system are recorded in transactions. // All processes in NEO system are recorded in transactions.
// There are several types of transactions. // There are several types of transactions.
const ( const (
MinerType TransactionType = 0x00 MinerType TXType = 0x00
IssueType TransactionType = 0x01 IssueType TXType = 0x01
ClaimType TransactionType = 0x02 ClaimType TXType = 0x02
EnrollmentType TransactionType = 0x20 EnrollmentType TXType = 0x20
VotingType TransactionType = 0x24 VotingType TXType = 0x24
RegisterType TransactionType = 0x40 RegisterType TXType = 0x40
ContractType TransactionType = 0x80 ContractType TXType = 0x80
StateType TransactionType = 0x90 StateType TXType = 0x90
AgencyType TransactionType = 0xb0 AgencyType TXType = 0xb0
PublishType TransactionType = 0xd0 PublishType TXType = 0xd0
InvocationType TransactionType = 0xd1 InvocationType TXType = 0xd1
) )
// String implements the stringer interface. // String implements the stringer interface.
func (t TransactionType) String() string { func (t TXType) String() string {
switch t { switch t {
case MinerType: case MinerType:
return "miner transaction" return "miner transaction"

12
pkg/core/util.go Normal file
View file

@ -0,0 +1,12 @@
package core
import "github.com/CityOfZion/neo-go/pkg/util"
// Utilities for quick bootstrapping blockchains. Normally we should
// create the genisis block. For now (to speed up development) we will add
// The hashes manually.
func GenesisHashPrivNet() util.Uint256 {
hash, _ := util.Uint256DecodeString("996e37358dc369912041f966f8c5d8d3a8255ba5dcbd3447f8a82b55db869099")
return hash
}

View file

@ -79,6 +79,7 @@ func Base58Encode(bytes []byte) string {
return encoded return encoded
} }
// Base58CheckDecode decodes the given string.
func Base58CheckDecode(s string) (b []byte, err error) { func Base58CheckDecode(s string) (b []byte, err error) {
b, err = Base58Decode(s) b, err = Base58Decode(s)
if err != nil { if err != nil {

View file

@ -61,26 +61,28 @@ type Message struct {
Payload payload.Payload Payload payload.Payload
} }
type commandType string // CommandType represents the type of a message command.
type CommandType string
// valid commands used to send between nodes. // valid commands used to send between nodes.
const ( const (
cmdVersion commandType = "version" CMDVersion CommandType = "version"
cmdVerack = "verack" CMDVerack CommandType = "verack"
cmdGetAddr = "getaddr" CMDGetAddr CommandType = "getaddr"
cmdAddr = "addr" CMDAddr CommandType = "addr"
cmdGetHeaders = "getheaders" CMDGetHeaders CommandType = "getheaders"
cmdHeaders = "headers" CMDHeaders CommandType = "headers"
cmdGetBlocks = "getblocks" CMDGetBlocks CommandType = "getblocks"
cmdInv = "inv" CMDInv CommandType = "inv"
cmdGetData = "getdata" CMDGetData CommandType = "getdata"
cmdBlock = "block" CMDBlock CommandType = "block"
cmdTX = "tx" CMDTX CommandType = "tx"
cmdConsensus = "consensus" CMDConsensus CommandType = "consensus"
cmdUnknown = "unknown" CMDUnknown CommandType = "unknown"
) )
func newMessage(magic NetMode, cmd commandType, p payload.Payload) *Message { // NewMessage returns a new message with the given payload.
func NewMessage(magic NetMode, cmd CommandType, p payload.Payload) *Message {
var ( var (
size uint32 size uint32
checksum []byte checksum []byte
@ -106,36 +108,36 @@ func newMessage(magic NetMode, cmd commandType, p payload.Payload) *Message {
} }
} }
// Converts the 12 byte command slice to a commandType. // CommandType converts the 12 byte command slice to a CommandType.
func (m *Message) commandType() commandType { func (m *Message) CommandType() CommandType {
cmd := cmdByteArrayToString(m.Command) cmd := cmdByteArrayToString(m.Command)
switch cmd { switch cmd {
case "version": case "version":
return cmdVersion return CMDVersion
case "verack": case "verack":
return cmdVerack return CMDVerack
case "getaddr": case "getaddr":
return cmdGetAddr return CMDGetAddr
case "addr": case "addr":
return cmdAddr return CMDAddr
case "getheaders": case "getheaders":
return cmdGetHeaders return CMDGetHeaders
case "headers": case "headers":
return cmdHeaders return CMDHeaders
case "getblocks": case "getblocks":
return cmdGetBlocks return CMDGetBlocks
case "inv": case "inv":
return cmdInv return CMDInv
case "getdata": case "getdata":
return cmdGetData return CMDGetData
case "block": case "block":
return cmdBlock return CMDBlock
case "tx": case "tx":
return cmdTX return CMDTX
case "consensus": case "consensus":
return cmdConsensus return CMDConsensus
default: default:
return cmdUnknown return CMDUnknown
} }
} }
@ -185,36 +187,35 @@ func (m *Message) decodePayload(r io.Reader) error {
return errChecksumMismatch return errChecksumMismatch
} }
//r = bytes.NewReader(buf)
r = buf r = buf
var p payload.Payload var p payload.Payload
switch m.commandType() { switch m.CommandType() {
case cmdVersion: case CMDVersion:
p = &payload.Version{} p = &payload.Version{}
if err := p.DecodeBinary(r); err != nil { if err := p.DecodeBinary(r); err != nil {
return err return err
} }
case cmdInv: case CMDInv:
p = &payload.Inventory{} p = &payload.Inventory{}
if err := p.DecodeBinary(r); err != nil { if err := p.DecodeBinary(r); err != nil {
return err return err
} }
case cmdAddr: case CMDAddr:
p = &payload.AddressList{} p = &payload.AddressList{}
if err := p.DecodeBinary(r); err != nil { if err := p.DecodeBinary(r); err != nil {
return err return err
} }
case cmdBlock: case CMDBlock:
p = &core.Block{} p = &core.Block{}
if err := p.DecodeBinary(r); err != nil { if err := p.DecodeBinary(r); err != nil {
return err return err
} }
case cmdGetHeaders: case CMDGetHeaders:
p = &payload.GetBlocks{} p = &payload.GetBlocks{}
if err := p.DecodeBinary(r); err != nil { if err := p.DecodeBinary(r); err != nil {
return err return err
} }
case cmdHeaders: case CMDHeaders:
p = &payload.Headers{} p = &payload.Headers{}
if err := p.DecodeBinary(r); err != nil { if err := p.DecodeBinary(r); err != nil {
return err return err
@ -242,7 +243,7 @@ func (m *Message) encode(w io.Writer) error {
// convert a command (string) to a byte slice filled with 0 bytes till // convert a command (string) to a byte slice filled with 0 bytes till
// size 12. // size 12.
func cmdToByteArray(cmd commandType) [cmdSize]byte { func cmdToByteArray(cmd CommandType) [cmdSize]byte {
cmdLen := len(cmd) cmdLen := len(cmd)
if cmdLen > cmdSize { if cmdLen > cmdSize {
panic("exceeded command max length of size 12") panic("exceeded command max length of size 12")

View file

@ -2,39 +2,33 @@ package network
import ( import (
"bytes" "bytes"
"reflect"
"testing" "testing"
"github.com/CityOfZion/neo-go/pkg/network/payload" "github.com/CityOfZion/neo-go/pkg/network/payload"
"github.com/stretchr/testify/assert"
) )
func TestMessageEncodeDecode(t *testing.T) { func TestMessageEncodeDecode(t *testing.T) {
m := newMessage(ModeTestNet, cmdVersion, nil) m := NewMessage(ModeTestNet, CMDVersion, nil)
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
if err := m.encode(buf); err != nil { if err := m.encode(buf); err != nil {
t.Error(err) t.Error(err)
} }
assert.Equal(t, len(buf.Bytes()), minMessageSize)
if n := len(buf.Bytes()); n < minMessageSize {
t.Fatalf("message should be at least %d bytes got %d", minMessageSize, n)
}
if n := len(buf.Bytes()); n > minMessageSize {
t.Fatalf("message without a payload should be exact %d bytes got %d", minMessageSize, n)
}
md := &Message{} md := &Message{}
if err := md.decode(buf); err != nil { if err := md.decode(buf); err != nil {
t.Error(err) t.Error(err)
} }
if !reflect.DeepEqual(m, md) { assert.Equal(t, m, md)
t.Errorf("both messages should be equal: %v != %v", m, md)
}
} }
func TestMessageEncodeDecodeWithVersion(t *testing.T) { func TestMessageEncodeDecodeWithVersion(t *testing.T) {
p := payload.NewVersion(12227, 2000, "/neo:2.6.0/", 0, true) var (
m := newMessage(ModeTestNet, cmdVersion, p) p = payload.NewVersion(12227, 2000, "/neo:2.6.0/", 0, true)
m = NewMessage(ModeTestNet, CMDVersion, p)
)
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
if err := m.encode(buf); err != nil { if err := m.encode(buf); err != nil {
@ -45,15 +39,14 @@ func TestMessageEncodeDecodeWithVersion(t *testing.T) {
if err := mDecode.decode(buf); err != nil { if err := mDecode.decode(buf); err != nil {
t.Fatal(err) t.Fatal(err)
} }
assert.Equal(t, m, mDecode)
if !reflect.DeepEqual(m, mDecode) {
t.Fatalf("expected both messages to be equal %v and %v", m, mDecode)
}
} }
func TestMessageInvalidChecksum(t *testing.T) { func TestMessageInvalidChecksum(t *testing.T) {
p := payload.NewVersion(1111, 3000, "/NEO:2.6.0/", 0, true) var (
m := newMessage(ModeTestNet, cmdVersion, p) p = payload.NewVersion(1111, 3000, "/NEO:2.6.0/", 0, true)
m = NewMessage(ModeTestNet, CMDVersion, p)
)
m.Checksum = 1337 m.Checksum = 1337
buf := new(bytes.Buffer) buf := new(bytes.Buffer)

219
pkg/network/node.go Normal file
View file

@ -0,0 +1,219 @@
package network
import (
"errors"
"fmt"
"os"
"time"
"github.com/CityOfZion/neo-go/pkg/core"
"github.com/CityOfZion/neo-go/pkg/network/payload"
"github.com/CityOfZion/neo-go/pkg/util"
log "github.com/go-kit/kit/log"
)
const (
protoVersion = 0
)
var protoTickInterval = 5 * time.Second
// Node represents the local node.
type Node struct {
// Config fields may not be modified while the server is running.
Config
logger log.Logger
server *Server
services uint64
bc *core.Blockchain
protoIn chan messageTuple
}
// messageTuple respresents a tuple that holds the message being
// send along with its peer.
type messageTuple struct {
peer Peer
msg *Message
}
func newNode(s *Server, cfg Config) *Node {
var startHash util.Uint256
if cfg.Net == ModePrivNet {
startHash = core.GenesisHashPrivNet()
}
bc := core.NewBlockchain(
core.NewMemoryStore(),
startHash,
)
logger := log.NewLogfmtLogger(os.Stderr)
logger = log.With(logger, "component", "node")
n := &Node{
Config: cfg,
protoIn: make(chan messageTuple),
server: s,
bc: bc,
logger: logger,
}
go n.handleMessages()
return n
}
func (n *Node) version() *payload.Version {
return payload.NewVersion(n.server.id, n.ListenTCP, n.UserAgent, 1, n.Relay)
}
func (n *Node) startProtocol(peer Peer) {
ticker := time.NewTicker(protoTickInterval).C
for {
select {
case <-ticker:
// Try to sync with the peer if his block height is higher then ours.
if peer.Version().StartHeight > n.bc.HeaderHeight() {
n.askMoreHeaders(peer)
}
// Only ask for more peers if the server has the capacity for it.
if n.server.hasCapacity() {
msg := NewMessage(n.Net, CMDGetAddr, nil)
peer.Send(msg)
}
case <-peer.Done():
return
}
}
}
// When a peer sends out his version we reply with verack after validating
// the version.
func (n *Node) handleVersionCmd(version *payload.Version, peer Peer) error {
msg := NewMessage(n.Net, CMDVerack, nil)
peer.Send(msg)
return nil
}
// handleInvCmd handles the forwarded inventory received from the peer.
// We will use the getdata message to get more details about the received
// inventory.
// note: if the server has Relay on false, inventory messages are not received.
func (n *Node) handleInvCmd(inv *payload.Inventory, peer Peer) error {
if !inv.Type.Valid() {
return fmt.Errorf("invalid inventory type received: %s", inv.Type)
}
if len(inv.Hashes) == 0 {
return errors.New("inventory has no hashes")
}
payload := payload.NewInventory(inv.Type, inv.Hashes)
peer.Send(NewMessage(n.Net, CMDGetData, payload))
return nil
}
// handleBlockCmd processes the received block received from its peer.
func (n *Node) handleBlockCmd(block *core.Block, peer Peer) error {
n.logger.Log(
"event", "block received",
"index", block.Index,
"hash", block.Hash(),
"tx", len(block.Transactions),
)
return n.bc.AddBlock(block)
}
// After a node sends out the getaddr message its receives a list of known peers
// in the network. handleAddrCmd processes that payload.
func (n *Node) handleAddrCmd(addressList *payload.AddressList, peer Peer) error {
addrs := make([]string, len(addressList.Addrs))
for i := 0; i < len(addrs); i++ {
addrs[i] = addressList.Addrs[i].Address.String()
}
n.server.connectToPeers(addrs...)
return nil
}
// The handleHeadersCmd will process the received headers from its peer.
// We call this in a routine cause we may block Peers Send() for to long.
func (n *Node) handleHeadersCmd(headers *payload.Headers, peer Peer) error {
go func(headers []*core.Header) {
if err := n.bc.AddHeaders(headers...); err != nil {
n.logger.Log("msg", "failed processing headers", "err", err)
return
}
// The peer will respond with a maximum of 2000 headers in one batch.
// We will ask one more batch here if needed. Eventually we will get synced
// due to the startProtocol routine that will ask headers every protoTick.
if n.bc.HeaderHeight() < peer.Version().StartHeight {
n.askMoreHeaders(peer)
}
}(headers.Hdrs)
return nil
}
// askMoreHeaders will send a getheaders message to the peer.
func (n *Node) askMoreHeaders(p Peer) {
start := []util.Uint256{n.bc.CurrentHeaderHash()}
payload := payload.NewGetBlocks(start, util.Uint256{})
p.Send(NewMessage(n.Net, CMDGetHeaders, payload))
}
// blockhain implements the Noder interface.
func (n *Node) blockchain() *core.Blockchain { return n.bc }
// handleProto implements the protoHandler interface.
func (n *Node) handleProto(msg *Message, p Peer) {
n.protoIn <- messageTuple{
msg: msg,
peer: p,
}
}
func (n *Node) handleMessages() {
for {
t := <-n.protoIn
var (
msg = t.msg
p = t.peer
err error
)
switch msg.CommandType() {
case CMDVersion:
version := msg.Payload.(*payload.Version)
err = n.handleVersionCmd(version, p)
case CMDAddr:
addressList := msg.Payload.(*payload.AddressList)
err = n.handleAddrCmd(addressList, p)
case CMDInv:
inventory := msg.Payload.(*payload.Inventory)
err = n.handleInvCmd(inventory, p)
case CMDBlock:
block := msg.Payload.(*core.Block)
err = n.handleBlockCmd(block, p)
case CMDHeaders:
headers := msg.Payload.(*payload.Headers)
err = n.handleHeadersCmd(headers, p)
case CMDVerack:
// Only start the protocol if we got the version and verack
// received.
if p.Version() != nil {
go n.startProtocol(p)
}
case CMDUnknown:
err = errors.New("received non-protocol messgae")
}
if err != nil {
n.logger.Log(
"msg", "failed processing message",
"command", msg.CommandType,
"err", err,
)
}
}
}

7
pkg/network/node_test.go Normal file
View file

@ -0,0 +1,7 @@
package network
import "testing"
func TestHandleVersion(t *testing.T) {
}

View file

@ -1,90 +0,0 @@
package payload
import (
"encoding/binary"
"io"
"github.com/CityOfZion/neo-go/pkg/util"
)
// AddrWithTime payload
type AddrWithTime struct {
// Timestamp the node connected to the network.
Timestamp uint32
Services uint64
Addr util.Endpoint
}
// NewAddrWithTime return a pointer to AddrWithTime.
func NewAddrWithTime(addr util.Endpoint) *AddrWithTime {
return &AddrWithTime{
Timestamp: 1337,
Services: 1,
Addr: addr,
}
}
// Size implements the payload interface.
func (p *AddrWithTime) Size() uint32 {
return 30
}
// DecodeBinary implements the Payload interface.
func (p *AddrWithTime) DecodeBinary(r io.Reader) error {
err := binary.Read(r, binary.LittleEndian, &p.Timestamp)
err = binary.Read(r, binary.LittleEndian, &p.Services)
err = binary.Read(r, binary.BigEndian, &p.Addr.IP)
err = binary.Read(r, binary.BigEndian, &p.Addr.Port)
return err
}
// EncodeBinary implements the Payload interface.
func (p *AddrWithTime) EncodeBinary(w io.Writer) error {
err := binary.Write(w, binary.LittleEndian, p.Timestamp)
err = binary.Write(w, binary.LittleEndian, p.Services)
err = binary.Write(w, binary.BigEndian, p.Addr.IP)
err = binary.Write(w, binary.BigEndian, p.Addr.Port)
return err
}
// AddressList is a list with AddrWithTime.
type AddressList struct {
Addrs []*AddrWithTime
}
// DecodeBinary implements the Payload interface.
func (p *AddressList) DecodeBinary(r io.Reader) error {
listLen := util.ReadVarUint(r)
p.Addrs = make([]*AddrWithTime, listLen)
for i := 0; i < int(listLen); i++ {
addr := &AddrWithTime{}
if err := addr.DecodeBinary(r); err != nil {
return err
}
p.Addrs[i] = addr
}
return nil
}
// EncodeBinary implements the Payload interface.
func (p *AddressList) EncodeBinary(w io.Writer) error {
// Write the length of the slice
util.WriteVarUint(w, uint64(len(p.Addrs)))
for _, addr := range p.Addrs {
if err := addr.EncodeBinary(w); err != nil {
return err
}
}
return nil
}
// Size implements the Payloader interface.
func (p *AddressList) Size() uint32 {
return uint32(len(p.Addrs) * 30)
}

View file

@ -1,55 +0,0 @@
package payload
import (
"bytes"
"fmt"
"reflect"
"testing"
"github.com/CityOfZion/neo-go/pkg/util"
)
func TestEncodeDecodeAddr(t *testing.T) {
e, err := util.EndpointFromString("127.0.0.1:2000")
if err != nil {
t.Fatal(err)
}
addr := NewAddrWithTime(e)
buf := new(bytes.Buffer)
if err := addr.EncodeBinary(buf); err != nil {
t.Fatal(err)
}
addrDecode := &AddrWithTime{}
if err := addrDecode.DecodeBinary(buf); err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(addr, addrDecode) {
t.Fatalf("expected both addr payloads to be equal: %v and %v", addr, addrDecode)
}
}
func TestEncodeDecodeAddressList(t *testing.T) {
var lenList uint8 = 4
addrList := &AddressList{make([]*AddrWithTime, lenList)}
for i := 0; i < int(lenList); i++ {
e, _ := util.EndpointFromString(fmt.Sprintf("127.0.0.1:200%d", i))
addrList.Addrs[i] = NewAddrWithTime(e)
}
buf := new(bytes.Buffer)
if err := addrList.EncodeBinary(buf); err != nil {
t.Fatal(err)
}
addrListDecode := &AddressList{}
if err := addrListDecode.DecodeBinary(buf); err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(addrList, addrListDecode) {
t.Fatalf("expected both address list payloads to be equal: %v and %v", addrList, addrListDecode)
}
}

View file

@ -0,0 +1,85 @@
package payload
import (
"encoding/binary"
"io"
"time"
"github.com/CityOfZion/neo-go/pkg/util"
)
// AddressAndTime payload.
type AddressAndTime struct {
Timestamp uint32
Services uint64
Address util.Endpoint
}
// NewAddressAndTime creates a new AddressAndTime object.
func NewAddressAndTime(e util.Endpoint, t time.Time) *AddressAndTime {
return &AddressAndTime{
Timestamp: uint32(t.UTC().Unix()),
Services: 1,
Address: e,
}
}
// DecodeBinary implements the Payload interface.
func (p *AddressAndTime) DecodeBinary(r io.Reader) error {
if err := binary.Read(r, binary.LittleEndian, &p.Timestamp); err != nil {
return err
}
if err := binary.Read(r, binary.LittleEndian, &p.Services); err != nil {
return err
}
if err := binary.Read(r, binary.BigEndian, &p.Address.IP); err != nil {
return err
}
return binary.Read(r, binary.BigEndian, &p.Address.Port)
}
// EncodeBinary implements the Payload interface.
func (p *AddressAndTime) EncodeBinary(w io.Writer) error {
if err := binary.Write(w, binary.LittleEndian, p.Timestamp); err != nil {
return err
}
if err := binary.Write(w, binary.LittleEndian, p.Services); err != nil {
return err
}
if err := binary.Write(w, binary.BigEndian, p.Address.IP); err != nil {
return err
}
return binary.Write(w, binary.BigEndian, p.Address.Port)
}
// AddressList is a list with AddrAndTime.
type AddressList struct {
Addrs []*AddressAndTime
}
// DecodeBinary implements the Payload interface.
func (p *AddressList) DecodeBinary(r io.Reader) error {
listLen := util.ReadVarUint(r)
p.Addrs = make([]*AddressAndTime, listLen)
for i := 0; i < int(listLen); i++ {
p.Addrs[i] = &AddressAndTime{}
if err := p.Addrs[i].DecodeBinary(r); err != nil {
return err
}
}
return nil
}
// EncodeBinary implements the Payload interface.
func (p *AddressList) EncodeBinary(w io.Writer) error {
if err := util.WriteVarUint(w, uint64(len(p.Addrs))); err != nil {
return err
}
for _, addr := range p.Addrs {
if err := addr.EncodeBinary(w); err != nil {
return err
}
}
return nil
}

View file

@ -0,0 +1,51 @@
package payload
import (
"bytes"
"fmt"
"testing"
"time"
"github.com/CityOfZion/neo-go/pkg/util"
"github.com/stretchr/testify/assert"
)
func TestEncodeDecodeAddress(t *testing.T) {
var (
e = util.NewEndpoint("127.0.0.1:2000")
addr = NewAddressAndTime(e, time.Now())
buf = new(bytes.Buffer)
)
if err := addr.EncodeBinary(buf); err != nil {
t.Fatal(err)
}
addrDecode := &AddressAndTime{}
if err := addrDecode.DecodeBinary(buf); err != nil {
t.Fatal(err)
}
assert.Equal(t, addr, addrDecode)
}
func TestEncodeDecodeAddressList(t *testing.T) {
var lenList uint8 = 4
addrList := &AddressList{make([]*AddressAndTime, lenList)}
for i := 0; i < int(lenList); i++ {
e := util.NewEndpoint(fmt.Sprintf("127.0.0.1:200%d", i))
addrList.Addrs[i] = NewAddressAndTime(e, time.Now())
}
buf := new(bytes.Buffer)
if err := addrList.EncodeBinary(buf); err != nil {
t.Fatal(err)
}
addrListDecode := &AddressList{}
if err := addrListDecode.DecodeBinary(buf); err != nil {
t.Fatal(err)
}
assert.Equal(t, addrList, addrListDecode)
}

View file

@ -37,6 +37,5 @@ func (p *Headers) EncodeBinary(w io.Writer) error {
return err return err
} }
} }
return nil return nil
} }

View file

@ -2,11 +2,11 @@ package payload
import ( import (
"bytes" "bytes"
"reflect"
"testing" "testing"
"github.com/CityOfZion/neo-go/pkg/core" "github.com/CityOfZion/neo-go/pkg/core"
"github.com/CityOfZion/neo-go/pkg/core/transaction" "github.com/CityOfZion/neo-go/pkg/core/transaction"
"github.com/stretchr/testify/assert"
) )
func TestHeadersEncodeDecode(t *testing.T) { func TestHeadersEncodeDecode(t *testing.T) {
@ -50,7 +50,9 @@ func TestHeadersEncodeDecode(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
if !reflect.DeepEqual(headers, headersDecode) { for i := 0; i < len(headers.Hdrs); i++ {
t.Fatalf("expected both header payload to be equal %+v and %+v", headers, headersDecode) assert.Equal(t, headers.Hdrs[i].Version, headersDecode.Hdrs[i].Version)
assert.Equal(t, headers.Hdrs[i].Index, headersDecode.Hdrs[i].Index)
assert.Equal(t, headers.Hdrs[i].Script, headersDecode.Hdrs[i].Script)
} }
} }

View file

@ -35,8 +35,8 @@ func (i InventoryType) Valid() bool {
// List of valid InventoryTypes. // List of valid InventoryTypes.
const ( const (
BlockType InventoryType = 0x01 // 1 BlockType InventoryType = 0x01 // 1
TXType = 0x02 // 2 TXType InventoryType = 0x02 // 2
ConsensusType = 0xe0 // 224 ConsensusType InventoryType = 0xe0 // 224
) )
// Inventory payload // Inventory payload

View file

@ -44,36 +44,63 @@ func NewVersion(id uint32, p uint16, ua string, h uint32, r bool) *Version {
// DecodeBinary implements the Payload interface. // DecodeBinary implements the Payload interface.
func (p *Version) DecodeBinary(r io.Reader) error { func (p *Version) DecodeBinary(r io.Reader) error {
err := binary.Read(r, binary.LittleEndian, &p.Version) if err := binary.Read(r, binary.LittleEndian, &p.Version); err != nil {
err = binary.Read(r, binary.LittleEndian, &p.Services) return err
err = binary.Read(r, binary.LittleEndian, &p.Timestamp) }
err = binary.Read(r, binary.LittleEndian, &p.Port) if err := binary.Read(r, binary.LittleEndian, &p.Services); err != nil {
err = binary.Read(r, binary.LittleEndian, &p.Nonce) return err
}
if err := binary.Read(r, binary.LittleEndian, &p.Timestamp); err != nil {
return err
}
if err := binary.Read(r, binary.LittleEndian, &p.Port); err != nil {
return err
}
if err := binary.Read(r, binary.LittleEndian, &p.Nonce); err != nil {
return err
}
var lenUA uint8 var lenUA uint8
err = binary.Read(r, binary.LittleEndian, &lenUA) if err := binary.Read(r, binary.LittleEndian, &lenUA); err != nil {
return err
}
p.UserAgent = make([]byte, lenUA) p.UserAgent = make([]byte, lenUA)
err = binary.Read(r, binary.LittleEndian, &p.UserAgent) if err := binary.Read(r, binary.LittleEndian, &p.UserAgent); err != nil {
return err
err = binary.Read(r, binary.LittleEndian, &p.StartHeight) }
err = binary.Read(r, binary.LittleEndian, &p.Relay) if err := binary.Read(r, binary.LittleEndian, &p.StartHeight); err != nil {
return err
return err }
return binary.Read(r, binary.LittleEndian, &p.Relay)
} }
// EncodeBinary implements the Payload interface. // EncodeBinary implements the Payload interface.
func (p *Version) EncodeBinary(w io.Writer) error { func (p *Version) EncodeBinary(w io.Writer) error {
err := binary.Write(w, binary.LittleEndian, p.Version) if err := binary.Write(w, binary.LittleEndian, p.Version); err != nil {
err = binary.Write(w, binary.LittleEndian, p.Services) return err
err = binary.Write(w, binary.LittleEndian, p.Timestamp) }
err = binary.Write(w, binary.LittleEndian, p.Port) if err := binary.Write(w, binary.LittleEndian, p.Services); err != nil {
err = binary.Write(w, binary.LittleEndian, p.Nonce) return err
err = binary.Write(w, binary.LittleEndian, uint8(len(p.UserAgent))) }
err = binary.Write(w, binary.LittleEndian, p.UserAgent) if err := binary.Write(w, binary.LittleEndian, p.Timestamp); err != nil {
err = binary.Write(w, binary.LittleEndian, p.StartHeight) return err
err = binary.Write(w, binary.LittleEndian, p.Relay) }
if err := binary.Write(w, binary.LittleEndian, p.Port); err != nil {
return err return err
}
if err := binary.Write(w, binary.LittleEndian, p.Nonce); err != nil {
return err
}
if err := binary.Write(w, binary.LittleEndian, uint8(len(p.UserAgent))); err != nil {
return err
}
if err := binary.Write(w, binary.LittleEndian, p.UserAgent); err != nil {
return err
}
if err := binary.Write(w, binary.LittleEndian, p.StartHeight); err != nil {
return err
}
return binary.Write(w, binary.LittleEndian, p.Relay)
} }
// Size implements the payloader interface. // Size implements the payloader interface.

View file

@ -5,48 +5,12 @@ import (
"github.com/CityOfZion/neo-go/pkg/util" "github.com/CityOfZion/neo-go/pkg/util"
) )
// Peer is the local representation of a remote node. It's an interface that may // A Peer is the local representation of a remote peer.
// be backed by any concrete transport: local, HTTP, tcp. // It's an interface that may be backed by any concrete
// transport.
type Peer interface { type Peer interface {
id() uint32 Version() *payload.Version
addr() util.Endpoint Endpoint() util.Endpoint
disconnect() Send(*Message)
Send(*Message) error Done() chan struct{}
version() *payload.Version
} }
// LocalPeer is the simplest kind of peer, mapped to a server in the
// same process-space.
type LocalPeer struct {
s *Server
nonce uint32
endpoint util.Endpoint
pVersion *payload.Version
}
// NewLocalPeer return a LocalPeer.
func NewLocalPeer(s *Server) *LocalPeer {
e, _ := util.EndpointFromString("1.1.1.1:1111")
return &LocalPeer{endpoint: e, s: s}
}
func (p *LocalPeer) Send(msg *Message) error {
switch msg.commandType() {
case cmdVersion:
version := msg.Payload.(*payload.Version)
return p.s.handleVersionCmd(version, p)
case cmdGetAddr:
return p.s.handleGetaddrCmd(msg, p)
default:
return nil
}
}
// Version implements the Peer interface.
func (p *LocalPeer) version() *payload.Version {
return p.pVersion
}
func (p *LocalPeer) id() uint32 { return p.nonce }
func (p *LocalPeer) addr() util.Endpoint { return p.endpoint }
func (p *LocalPeer) disconnect() {}

22
pkg/network/protocol.go Normal file
View file

@ -0,0 +1,22 @@
package network
import (
"github.com/CityOfZion/neo-go/pkg/core"
"github.com/CityOfZion/neo-go/pkg/network/payload"
)
// A ProtoHandler is an interface that abstract the implementation
// of the NEO protocol.
type ProtoHandler interface {
version() *payload.Version
handleProto(*Message, Peer)
}
type protoHandleFunc func(*Message, Peer)
// Noder is anything that implements the NEO protocol
// and can return the Blockchain object.
type Noder interface {
ProtoHandler
blockchain() *core.Blockchain
}

View file

@ -1,129 +0,0 @@
package network
import (
"encoding/json"
"fmt"
"net/http"
)
const (
rpcPortMainNet = 20332
rpcPortTestNet = 10332
rpcVersion = "2.0"
// error response messages
methodNotFound = "Method not found"
parseError = "Parse error"
)
// Each NEO node has a set of optional APIs for accessing blockchain
// data and making things easier for development of blockchain apps.
// APIs are provided via JSON-RPC , comm at bottom layer is with http/https protocol.
// listenHTTP creates an ingress bridge from the outside world to the passed
// server, by installing handlers for all the necessary RPCs to the passed mux.
func listenHTTP(s *Server, port int) {
api := &API{s}
p := fmt.Sprintf(":%d", port)
s.logger.Printf("serving RPC on %d", port)
s.logger.Printf("%s", http.ListenAndServe(p, api))
}
// API serves JSON-RPC.
type API struct {
s *Server
}
func (s *API) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Official nodes respond a parse error if the method is not POST.
// Instead of returning a decent response for this, let's do the same.
if r.Method != "POST" {
writeError(w, 0, 0, parseError)
}
var req Request
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeError(w, 0, 0, parseError)
return
}
defer r.Body.Close()
if req.Version != rpcVersion {
writeJSON(w, http.StatusBadRequest, nil)
return
}
switch req.Method {
case "getconnectioncount":
if err := s.getConnectionCount(w, &req); err != nil {
writeError(w, 0, 0, parseError)
return
}
case "getblockcount":
case "getbestblockhash":
default:
writeError(w, 0, 0, methodNotFound)
}
}
// This is an Example on how we could handle incomming RPC requests.
func (s *API) getConnectionCount(w http.ResponseWriter, req *Request) error {
count := s.s.peerCount()
resp := ConnectionCountResponse{
Version: rpcVersion,
Result: count,
ID: 1,
}
return writeJSON(w, http.StatusOK, resp)
}
// writeError returns a JSON error with given parameters. All error HTTP
// status codes are 200. According to the official API.
func writeError(w http.ResponseWriter, id, code int, msg string) error {
resp := RequestError{
Version: rpcVersion,
ID: id,
Error: Error{
Code: code,
Message: msg,
},
}
return writeJSON(w, http.StatusOK, resp)
}
func writeJSON(w http.ResponseWriter, status int, v interface{}) error {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
return json.NewEncoder(w).Encode(v)
}
// Request is an object received through JSON-RPC from the client.
type Request struct {
Version string `json:"jsonrpc"`
Method string `json:"method"`
Params []string `json:"params"`
ID int `json:"id"`
}
// ConnectionCountResponse ..
type ConnectionCountResponse struct {
Version string `json:"jsonrpc"`
Result int `json:"result"`
ID int `json:"id"`
}
// RequestError ..
type RequestError struct {
Version string `json:"jsonrpc"`
ID int `json:"id"`
Error Error `json:"error"`
}
// Error holds information about an RCP error.
type Error struct {
Code int `json:"code"`
Message string `json:"message"`
}

View file

@ -1,354 +1,282 @@
package network package network
import ( import (
"context"
"errors"
"fmt" "fmt"
"log"
"net" "net"
"os" "os"
"sync" "text/tabwriter"
"time" "time"
"github.com/CityOfZion/neo-go/pkg/core"
"github.com/CityOfZion/neo-go/pkg/network/payload"
"github.com/CityOfZion/neo-go/pkg/util" "github.com/CityOfZion/neo-go/pkg/util"
log "github.com/go-kit/kit/log"
) )
const ( const (
// node version // node version
version = "2.6.0" version = "2.6.0"
// official ports according to the protocol. // official ports according to the protocol.
portMainNet = 10333 portMainNet = 10333
portTestNet = 20333 portTestNet = 20333
maxPeers = 50 maxPeers = 50
) )
type messageTuple struct { var dialTimeout = 4 * time.Second
peer Peer
msg *Message // Config holds the server configuration.
type Config struct {
// MaxPeers it the maximum numbers of peers that can
// be connected to the server.
MaxPeers int
// The user agent of the server.
UserAgent string
// The listen address of the TCP server.
ListenTCP uint16
// The listen address of the RPC server.
ListenRPC uint16
// The network mode this server will operate on.
// ModePrivNet docker private network.
// ModeTestNet NEO test network.
// ModeMainNet NEO main network.
Net NetMode
// Relay determins whether the server is forwarding its inventory.
Relay bool
// Seeds are a list of initial nodes used to establish connectivity.
Seeds []string
// Maximum duration a single dial may take.
DialTimeout time.Duration
} }
// Server is the representation of a full working NEO TCP node. // Server manages all incoming peer connections.
type Server struct { type Server struct {
logger *log.Logger // Config fields may not be modified while the server is running.
// id of the server Config
// Proto is just about anything that can handle the NEO protocol.
// In production enviroments the ProtoHandler is mostly the local node.
proto ProtoHandler
// Unique id of this server.
id uint32 id uint32
// the port the TCP listener is listening on.
port uint16 logger log.Logger
// userAgent of the server.
userAgent string
// The "magic" mode the server is currently running on.
// This can either be 0x00746e41 or 0x74746e41 for main or test net.
// Or 56753 to work with the docker privnet.
net NetMode
// map that holds all connected peers to this server.
peers map[Peer]bool
// channel for handling new registerd peers.
register chan Peer
// channel for safely removing and disconnecting peers.
unregister chan Peer
// channel for coordinating messages.
message chan messageTuple
// channel used to gracefull shutdown the server.
quit chan struct{}
// Whether this server will receive and forward messages.
relay bool
// TCP listener of the server
listener net.Listener listener net.Listener
// channel for safely responding the number of current connected peers.
peerCountCh chan peerCount register chan Peer
// a list of hashes that unregister chan Peer
knownHashes protectedHashmap
// The blockchain. badAddrOp chan func(map[string]bool)
bc *core.Blockchain badAddrOpDone chan struct{}
peerOp chan func(map[Peer]bool)
peerOpDone chan struct{}
quit chan struct{}
} }
// TODO: Maybe util is a better place for such data types. // NewServer returns a new Server object created from the
type protectedHashmap struct { // given config.
*sync.RWMutex func NewServer(cfg Config) *Server {
hashes map[util.Uint256]bool if cfg.MaxPeers == 0 {
} cfg.MaxPeers = maxPeers
func (m protectedHashmap) add(h util.Uint256) bool {
m.Lock()
defer m.Unlock()
if _, ok := m.hashes[h]; !ok {
m.hashes[h] = true
return true
} }
return false if cfg.Net == 0 {
} cfg.Net = ModeTestNet
func (m protectedHashmap) remove(h util.Uint256) bool {
m.Lock()
defer m.Unlock()
if _, ok := m.hashes[h]; ok {
delete(m.hashes, h)
return true
} }
return false if cfg.DialTimeout == 0 {
} cfg.DialTimeout = dialTimeout
func (m protectedHashmap) has(h util.Uint256) bool {
m.RLock()
defer m.RUnlock()
_, ok := m.hashes[h]
return ok
}
// NewServer returns a pointer to a new server.
func NewServer(net NetMode) *Server {
logger := log.New(os.Stdout, "[NEO SERVER] :: ", 0)
if net != ModeTestNet && net != ModeMainNet && net != ModePrivNet {
logger.Fatalf("invalid network mode %d", net)
} }
// For now I will hard code a genesis block of the docker privnet container. logger := log.NewLogfmtLogger(os.Stderr)
startHash, _ := util.Uint256DecodeString("996e37358dc369912041f966f8c5d8d3a8255ba5dcbd3447f8a82b55db869099") logger = log.With(logger, "component", "server")
s := &Server{ s := &Server{
id: util.RandUint32(1111111, 9999999), Config: cfg,
userAgent: fmt.Sprintf("/NEO:%s/", version), logger: logger,
logger: logger, id: util.RandUint32(1000000, 9999999),
peers: make(map[Peer]bool), quit: make(chan struct{}, 1),
register: make(chan Peer), register: make(chan Peer),
unregister: make(chan Peer), unregister: make(chan Peer),
message: make(chan messageTuple), badAddrOp: make(chan func(map[string]bool)),
relay: true, // currently relay is not handled. badAddrOpDone: make(chan struct{}),
net: net, peerOp: make(chan func(map[Peer]bool)),
quit: make(chan struct{}), peerOpDone: make(chan struct{}),
peerCountCh: make(chan peerCount),
bc: core.NewBlockchain(core.NewMemoryStore(), logger, startHash),
} }
s.proto = newNode(s, cfg)
return s return s
} }
// Start run's the server. func (s *Server) createListener() error {
// TODO: server should be initialized with a config. ln, err := net.Listen("tcp", fmt.Sprintf(":%d", s.ListenTCP))
func (s *Server) Start(opts StartOpts) {
s.port = uint16(opts.TCP)
fmt.Println(logo())
fmt.Println(string(s.userAgent))
fmt.Println("")
s.logger.Printf("NET: %s - TCP: %d - RELAY: %v - ID: %d",
s.net, int(s.port), s.relay, s.id)
go listenTCP(s, opts.TCP)
if opts.RPC > 0 {
go listenHTTP(s, opts.RPC)
}
if len(opts.Seeds) > 0 {
connectToSeeds(s, opts.Seeds)
}
s.loop()
}
// Stop the server, attemping a gracefull shutdown.
func (s *Server) Stop() { s.quit <- struct{}{} }
// shutdown the server, disconnecting all peers.
func (s *Server) shutdown() {
s.logger.Println("attemping a quitefull shutdown.")
s.listener.Close()
// disconnect and remove all connected peers.
for peer := range s.peers {
peer.disconnect()
}
}
func (s *Server) loop() {
for {
select {
// When a new connection is been established, (by this server or remote node)
// its peer will be received on this channel.
// Any peer registration must happen via this channel.
case peer := <-s.register:
if len(s.peers) < maxPeers {
s.logger.Printf("peer registered from address %s", peer.addr())
s.peers[peer] = true
if err := s.handlePeerConnected(peer); err != nil {
s.logger.Printf("failed handling peer connection: %s", err)
peer.disconnect()
}
}
// unregister safely deletes a peer. For disconnecting peers use the
// disconnect() method on the peer, it will call unregister and terminates its routines.
case peer := <-s.unregister:
if _, ok := s.peers[peer]; ok {
delete(s.peers, peer)
s.logger.Printf("peer %s disconnected", peer.addr())
}
case t := <-s.peerCountCh:
t.count <- len(s.peers)
case <-s.quit:
s.shutdown()
}
}
}
// When a new peer is connected we send our version.
// No further communication should be made before both sides has received
// the versions of eachother.
func (s *Server) handlePeerConnected(p Peer) error {
// TODO: get the blockheight of this server once core implemented this.
payload := payload.NewVersion(s.id, s.port, s.userAgent, s.bc.HeaderHeight(), s.relay)
msg := newMessage(s.net, cmdVersion, payload)
return p.Send(msg)
}
func (s *Server) handleVersionCmd(version *payload.Version, p Peer) error {
if s.id == version.Nonce {
return errors.New("identical nonce")
}
if p.addr().Port != version.Port {
return fmt.Errorf("port mismatch: %d and %d", version.Port, p.addr().Port)
}
return p.Send(
newMessage(s.net, cmdVerack, nil),
)
}
func (s *Server) handleGetaddrCmd(msg *Message, p Peer) error {
return nil
}
// The node can broadcast the object information it owns by this message.
// The message can be sent automatically or can be used to answer getbloks messages.
func (s *Server) handleInvCmd(inv *payload.Inventory, p Peer) error {
if !inv.Type.Valid() {
return fmt.Errorf("invalid inventory type %s", inv.Type)
}
if len(inv.Hashes) == 0 {
return errors.New("inventory should have at least 1 hash got 0")
}
// todo: only grab the hashes that we dont know.
payload := payload.NewInventory(inv.Type, inv.Hashes)
resp := newMessage(s.net, cmdGetData, payload)
return p.Send(resp)
}
// handleBlockCmd processes the received block.
func (s *Server) handleBlockCmd(block *core.Block, p Peer) error {
hash, err := block.Hash()
if err != nil { if err != nil {
return err return err
} }
s.listener = ln
s.logger.Printf("new block: index %d hash %s", block.Index, hash)
return nil return nil
} }
// After receiving the getaddr message, the node returns an addr message as response func (s *Server) listenTCP() {
// and provides information about the known nodes on the network. for {
func (s *Server) handleAddrCmd(addrList *payload.AddressList, p Peer) error { conn, err := s.listener.Accept()
for _, addr := range addrList.Addrs { if err != nil {
if !s.peerAlreadyConnected(addr.Addr) { s.logger.Log("msg", "conn read error", "err", err)
// TODO: this is not transport abstracted. break
go connectToRemoteNode(s, addr.Addr.String()) }
go s.setupConnection(conn)
}
s.Quit()
}
func (s *Server) setupConnection(conn net.Conn) {
if !s.hasCapacity() {
s.logger.Log("msg", "server reached maximum capacity")
return
}
p := NewTCPPeer(conn, s.proto.handleProto)
s.register <- p
if err := p.run(); err != nil {
s.unregister <- p
}
}
func (s *Server) connectToPeers(addrs ...string) {
for _, addr := range addrs {
if s.hasCapacity() && s.canConnectWith(addr) {
go func(addr string) {
conn, err := net.DialTimeout("tcp", addr, s.DialTimeout)
if err != nil {
s.badAddrOp <- func(badAddrs map[string]bool) {
badAddrs[addr] = true
}
<-s.badAddrOpDone
return
}
go s.setupConnection(conn)
}(addr)
} }
} }
return nil
} }
// Handle the headers received from the remote after we asked for headers with the func (s *Server) canConnectWith(addr string) bool {
// "getheaders" message. canConnect := true
func (s *Server) handleHeadersCmd(headers *payload.Headers, p Peer) error { s.peerOp <- func(peers map[Peer]bool) {
// Set a deadline for adding headers? for peer := range peers {
go func(ctx context.Context, headers []*core.Header) { if peer.Endpoint().String() == addr {
if err := s.bc.AddHeaders(headers...); err != nil { canConnect = false
s.logger.Printf("failed to add headers: %s", err) break
return
}
// Ask more headers if we are not in sync with the peer.
if s.bc.HeaderHeight() < p.version().StartHeight {
if err := s.askMoreHeaders(p); err != nil {
s.logger.Printf("getheaders RPC failed: %s", err)
return
} }
} }
}(context.TODO(), headers.Hdrs) }
<-s.peerOpDone
if !canConnect {
return false
}
return nil s.badAddrOp <- func(badAddrs map[string]bool) {
_, ok := badAddrs[addr]
canConnect = !ok
}
<-s.badAddrOpDone
return canConnect
} }
// Ask the peer for more headers We use the current block hash as start. func (s *Server) hasCapacity() bool {
func (s *Server) askMoreHeaders(p Peer) error { return s.PeerCount() != s.MaxPeers
start := []util.Uint256{s.bc.CurrentHeaderHash()}
payload := payload.NewGetBlocks(start, util.Uint256{})
msg := newMessage(s.net, cmdGetHeaders, payload)
return p.Send(msg)
} }
// check if the addr is already connected to the server. func (s *Server) sendVersion(peer Peer) {
func (s *Server) peerAlreadyConnected(addr net.Addr) bool { peer.Send(NewMessage(s.Net, CMDVersion, s.proto.version()))
// TODO: Dont try to connect with ourselfs. }
for peer := range s.peers {
if peer.addr().String() == addr.String() { func (s *Server) run() {
return true var (
ticker = time.NewTicker(30 * time.Second).C
peers = make(map[Peer]bool)
badAddrs = make(map[string]bool)
)
for {
select {
case op := <-s.badAddrOp:
op(badAddrs)
s.badAddrOpDone <- struct{}{}
case op := <-s.peerOp:
op(peers)
s.peerOpDone <- struct{}{}
case p := <-s.register:
peers[p] = true
// When a new peer connection is established, we send
// out our version immediately.
s.sendVersion(p)
s.logger.Log("event", "peer connected", "endpoint", p.Endpoint())
case p := <-s.unregister:
delete(peers, p)
s.logger.Log("event", "peer disconnected", "endpoint", p.Endpoint())
case <-ticker:
s.printState()
case <-s.quit:
return
} }
} }
return false
} }
// TODO: Quit this routine if the peer is disconnected. // PeerCount returns the number of current connected peers.
func (s *Server) startProtocol(p Peer) { func (s *Server) PeerCount() (n int) {
if s.bc.HeaderHeight() < p.version().StartHeight { s.peerOp <- func(peers map[Peer]bool) {
s.askMoreHeaders(p) n = len(peers)
}
for {
getaddrMsg := newMessage(s.net, cmdGetAddr, nil)
p.Send(getaddrMsg)
time.Sleep(30 * time.Second)
} }
<-s.peerOpDone
return
} }
type peerCount struct { func (s *Server) Start() error {
count chan int fmt.Println(logo())
} fmt.Println("")
s.printConfiguration()
// peerCount returns the number of connected peers to this server. if err := s.createListener(); err != nil {
func (s *Server) peerCount() int { return err
ch := peerCount{
count: make(chan int),
} }
s.peerCountCh <- ch go s.run()
go s.listenTCP()
return <-ch.count go s.connectToPeers(s.Seeds...)
select {}
} }
// StartOpts holds the server configuration. func (s *Server) Quit() {
type StartOpts struct { s.quit <- struct{}{}
// tcp port }
TCP int
// slice of peer addresses the server will connect to func (s *Server) printState() {
Seeds []string w := tabwriter.NewWriter(os.Stdout, 0, 0, 4, ' ', 0)
// JSON-RPC port. If 0 no RPC handler will be attached. fmt.Fprintf(w, "connected peers:\t%d/%d\n", s.PeerCount(), s.MaxPeers)
RPC int w.Flush()
}
func (s *Server) printConfiguration() {
w := tabwriter.NewWriter(os.Stdout, 0, 0, 4, ' ', 0)
fmt.Fprintf(w, "user agent:\t%s\n", s.UserAgent)
fmt.Fprintf(w, "id:\t%d\n", s.id)
fmt.Fprintf(w, "network:\t%s\n", s.Net)
fmt.Fprintf(w, "listen TCP:\t%d\n", s.ListenTCP)
fmt.Fprintf(w, "listen RPC:\t%d\n", s.ListenRPC)
fmt.Fprintf(w, "relay:\t%v\n", s.Relay)
fmt.Fprintf(w, "max peers:\t%d\n", s.MaxPeers)
chainer := s.proto.(Noder)
fmt.Fprintf(w, "current height:\t%d\n", chainer.blockchain().HeaderHeight())
fmt.Fprintln(w, "")
w.Flush()
} }
func logo() string { func logo() string {

View file

@ -1,60 +1,86 @@
package network package network
import ( import (
"os"
"testing" "testing"
"github.com/CityOfZion/neo-go/pkg/network/payload" "github.com/CityOfZion/neo-go/pkg/network/payload"
"github.com/CityOfZion/neo-go/pkg/util"
log "github.com/go-kit/kit/log"
"github.com/stretchr/testify/assert"
) )
// TODO this should be moved to localPeer test. func TestRegisterPeer(t *testing.T) {
s := newTestServer()
go s.run()
func TestHandleVersionFailWrongPort(t *testing.T) { assert.NotZero(t, s.id)
s := NewServer(ModePrivNet) assert.Zero(t, s.PeerCount())
go s.loop()
p := NewLocalPeer(s) lenPeers := 10
for i := 0; i < lenPeers; i++ {
s.register <- newTestPeer()
}
assert.Equal(t, lenPeers, s.PeerCount())
}
version := payload.NewVersion(1337, 1, "/NEO:0.0.0/", 0, true) func TestUnregisterPeer(t *testing.T) {
if err := s.handleVersionCmd(version, p); err == nil { s := newTestServer()
t.Fatal("expected error got nil") go s.run()
peer := newTestPeer()
s.register <- peer
s.register <- newTestPeer()
s.register <- newTestPeer()
assert.Equal(t, 3, s.PeerCount())
s.unregister <- peer
assert.Equal(t, 2, s.PeerCount())
}
type testNode struct{}
func (t testNode) version() *payload.Version {
return &payload.Version{}
}
func (t testNode) handleProto(msg *Message, p Peer) {}
func newTestServer() *Server {
return &Server{
logger: log.NewLogfmtLogger(os.Stderr),
id: util.RandUint32(1000000, 9999999),
quit: make(chan struct{}, 1),
register: make(chan Peer),
unregister: make(chan Peer),
badAddrOp: make(chan func(map[string]bool)),
badAddrOpDone: make(chan struct{}),
peerOp: make(chan func(map[Peer]bool)),
peerOpDone: make(chan struct{}),
proto: testNode{},
} }
} }
func TestHandleVersionFailIdenticalNonce(t *testing.T) { type testPeer struct {
s := NewServer(ModePrivNet) done chan struct{}
go s.loop() }
p := NewLocalPeer(s) func newTestPeer() testPeer {
return testPeer{
version := payload.NewVersion(s.id, 1, "/NEO:0.0.0/", 0, true) done: make(chan struct{}),
if err := s.handleVersionCmd(version, p); err == nil {
t.Fatal("expected error got nil")
} }
} }
func TestHandleVersion(t *testing.T) { func (p testPeer) Version() *payload.Version {
s := NewServer(ModePrivNet) return &payload.Version{}
go s.loop()
p := NewLocalPeer(s)
version := payload.NewVersion(1337, p.addr().Port, "/NEO:0.0.0/", 0, true)
if err := s.handleVersionCmd(version, p); err != nil {
t.Fatal(err)
}
} }
func TestHandleAddrCmd(t *testing.T) { func (p testPeer) Endpoint() util.Endpoint {
// todo return util.Endpoint{}
} }
func TestHandleGetAddrCmd(t *testing.T) { func (p testPeer) Send(msg *Message) {}
// todo
}
func TestHandleInv(t *testing.T) { func (p testPeer) Done() chan struct{} {
// todo return p.done
}
func TestHandleBlockCmd(t *testing.T) {
// todo
} }

View file

@ -1,254 +0,0 @@
package network
import (
"bytes"
"errors"
"fmt"
"net"
"github.com/CityOfZion/neo-go/pkg/core"
"github.com/CityOfZion/neo-go/pkg/network/payload"
"github.com/CityOfZion/neo-go/pkg/util"
)
func listenTCP(s *Server, port int) error {
ln, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
if err != nil {
return err
}
s.listener = ln
for {
conn, err := ln.Accept()
if err != nil {
return err
}
go handleConnection(s, conn)
}
}
func connectToRemoteNode(s *Server, address string) {
conn, err := net.Dial("tcp", address)
if err != nil {
s.logger.Printf("failed to connect to remote node %s", address)
if conn != nil {
conn.Close()
}
return
}
go handleConnection(s, conn)
}
func connectToSeeds(s *Server, addrs []string) {
for _, addr := range addrs {
go connectToRemoteNode(s, addr)
}
}
func handleConnection(s *Server, conn net.Conn) {
peer := NewTCPPeer(conn, s)
s.register <- peer
// remove the peer from connected peers and cleanup the connection.
defer func() {
peer.disconnect()
}()
// Start a goroutine that will handle all outgoing messages.
go peer.writeLoop()
// Start a goroutine that will handle all incomming messages.
go handleMessage(s, peer)
// Read from the connection and decode it into a Message ready for processing.
for {
msg := &Message{}
if err := msg.decode(conn); err != nil {
s.logger.Printf("decode error: %s", err)
break
}
peer.receive <- msg
}
}
// handleMessage multiplexes the message received from a TCP connection to a server command.
func handleMessage(s *Server, p *TCPPeer) {
var err error
for {
msg := <-p.receive
command := msg.commandType()
// s.logger.Printf("IN :: %d :: %s :: %v", p.id(), command, msg)
switch command {
case cmdVersion:
version := msg.Payload.(*payload.Version)
if err = s.handleVersionCmd(version, p); err != nil {
break
}
p.nonce = version.Nonce
p.pVersion = version
// When a node receives a connection request, it declares its version immediately.
// There will be no other communication until both sides are getting versions of each other.
// When a node receives the version message, it replies to a verack as a response immediately.
// NOTE: The current official NEO nodes dont mimic this behaviour. There is small chance that the
// official nodes will not respond directly with a verack after we sended our version.
// is this a bug? - anthdm 02/02/2018
msgVerack := <-p.receive
if msgVerack.commandType() != cmdVerack {
err = errors.New("expected verack after sended out version")
break
}
// start the protocol
go s.startProtocol(p)
case cmdAddr:
addrList := msg.Payload.(*payload.AddressList)
err = s.handleAddrCmd(addrList, p)
case cmdGetAddr:
err = s.handleGetaddrCmd(msg, p)
case cmdInv:
inv := msg.Payload.(*payload.Inventory)
err = s.handleInvCmd(inv, p)
case cmdBlock:
block := msg.Payload.(*core.Block)
err = s.handleBlockCmd(block, p)
case cmdConsensus:
case cmdTX:
case cmdVerack:
// If we receive a verack here we disconnect. We already handled the verack
// when we sended our version.
err = errors.New("verack already received")
case cmdGetHeaders:
case cmdGetBlocks:
case cmdGetData:
case cmdHeaders:
headers := msg.Payload.(*payload.Headers)
err = s.handleHeadersCmd(headers, p)
default:
// This command is unknown by the server.
err = fmt.Errorf("unknown command received %v", msg.Command)
break
}
// catch all errors here and disconnect.
if err != nil {
s.logger.Printf("processing message failed: %s", err)
break
}
}
// Disconnect the peer when breaked out of the loop.
p.disconnect()
}
type sendTuple struct {
msg *Message
err chan error
}
// TCPPeer represents a remote node, backed by TCP transport.
type TCPPeer struct {
s *Server
// nonce (id) of the peer.
nonce uint32
// underlying TCP connection
conn net.Conn
// host and port information about this peer.
endpoint util.Endpoint
// channel to coordinate messages writen back to the connection.
send chan sendTuple
// channel to receive from underlying connection.
receive chan *Message
// the version sended out by the peer when connected.
pVersion *payload.Version
}
// NewTCPPeer returns a pointer to a TCP Peer.
func NewTCPPeer(conn net.Conn, s *Server) *TCPPeer {
e, _ := util.EndpointFromString(conn.RemoteAddr().String())
return &TCPPeer{
conn: conn,
send: make(chan sendTuple),
receive: make(chan *Message),
endpoint: e,
s: s,
}
}
// Send needed to implement the network.Peer interface
// and provide the functionality to send a message to
// the current peer.
func (p *TCPPeer) Send(msg *Message) error {
t := sendTuple{
msg: msg,
err: make(chan error),
}
p.send <- t
return <-t.err
}
func (p *TCPPeer) version() *payload.Version {
return p.pVersion
}
// id implements the peer interface
func (p *TCPPeer) id() uint32 {
return p.nonce
}
// endpoint implements the peer interface
func (p *TCPPeer) addr() util.Endpoint {
return p.endpoint
}
// disconnect disconnects the peer, cleaning up all its resources.
// 3 goroutines needs to be cleanup (writeLoop, handleConnection and handleMessage)
func (p *TCPPeer) disconnect() {
select {
case <-p.send:
case <-p.receive:
default:
close(p.send)
close(p.receive)
p.s.unregister <- p
p.conn.Close()
}
}
// writeLoop writes messages to the underlying TCP connection.
// A goroutine writeLoop is started for each connection.
// There should be at most one writer to a connection executing
// all writes from this goroutine.
func (p *TCPPeer) writeLoop() {
// clean up the connection.
defer func() {
p.disconnect()
}()
// resuse this buffer
buf := new(bytes.Buffer)
for {
t := <-p.send
if t.msg == nil {
break // send probably closed.
}
// p.s.logger.Printf("OUT :: %s :: %+v", t.msg.commandType(), t.msg.Payload)
if err := t.msg.encode(buf); err != nil {
t.err <- err
}
_, err := p.conn.Write(buf.Bytes())
t.err <- err
buf.Reset()
}
}

134
pkg/network/tcp_peer.go Normal file
View file

@ -0,0 +1,134 @@
package network
import (
"bytes"
"net"
"os"
"time"
"github.com/CityOfZion/neo-go/pkg/network/payload"
"github.com/CityOfZion/neo-go/pkg/util"
log "github.com/go-kit/kit/log"
)
// TCPPeer represents a connected remote node in the
// network over TCP.
type TCPPeer struct {
// The endpoint of the peer.
endpoint util.Endpoint
// underlying connection.
conn net.Conn
// The version the peer declared when connecting.
version *payload.Version
// connectedAt is the timestamp the peer connected to
// the network.
connectedAt time.Time
// handleProto is the handler that will handle the
// incoming message along with its peer.
handleProto protoHandleFunc
// Done is used to broadcast this peer has stopped running
// and should be removed as reference.
done chan struct{}
send chan *Message
logger log.Logger
}
// NewTCPPeer creates a new peer from a TCP connection.
func NewTCPPeer(conn net.Conn, fun protoHandleFunc) *TCPPeer {
e := util.NewEndpoint(conn.RemoteAddr().String())
logger := log.NewLogfmtLogger(os.Stderr)
logger = log.With(logger, "component", "peer", "endpoint", e)
return &TCPPeer{
endpoint: e,
conn: conn,
done: make(chan struct{}),
send: make(chan *Message),
logger: logger,
connectedAt: time.Now().UTC(),
handleProto: fun,
}
}
// Version implements the Peer interface.
func (p *TCPPeer) Version() *payload.Version {
return p.version
}
// Endpoint implements the Peer interface.
func (p *TCPPeer) Endpoint() util.Endpoint {
return p.endpoint
}
// Send implements the Peer interface.
func (p *TCPPeer) Send(msg *Message) {
p.send <- msg
}
// Done implemnets the Peer interface.
func (p *TCPPeer) Done() chan struct{} {
return p.done
}
func (p *TCPPeer) run() error {
errCh := make(chan error, 1)
go p.readLoop(errCh)
go p.writeLoop(errCh)
err := <-errCh
p.logger.Log("err", err)
p.cleanup()
return err
}
func (p *TCPPeer) readLoop(errCh chan error) {
for {
msg := &Message{}
if err := msg.decode(p.conn); err != nil {
errCh <- err
break
}
p.handleMessage(msg)
}
}
func (p *TCPPeer) writeLoop(errCh chan error) {
buf := new(bytes.Buffer)
for {
msg := <-p.send
if err := msg.encode(buf); err != nil {
errCh <- err
break
}
if _, err := p.conn.Write(buf.Bytes()); err != nil {
errCh <- err
break
}
buf.Reset()
}
}
func (p *TCPPeer) cleanup() {
p.conn.Close()
close(p.send)
p.done <- struct{}{}
}
func (p *TCPPeer) handleMessage(msg *Message) {
switch msg.CommandType() {
case CMDVersion:
version := msg.Payload.(*payload.Version)
p.version = version
p.handleProto(msg, p)
default:
p.handleProto(msg, p)
}
}

View file

@ -12,14 +12,11 @@ type Endpoint struct {
Port uint16 Port uint16
} }
// EndpointFromString returns an Endpoint from the given string. // NewEndpoint creates an Endpoint from the given string.
// For now this only handles the most simple hostport form. func NewEndpoint(s string) (e Endpoint) {
// e.g. 127.0.0.1:3000
// This should be enough to work with for now.
func EndpointFromString(s string) (Endpoint, error) {
hostPort := strings.Split(s, ":") hostPort := strings.Split(s, ":")
if len(hostPort) != 2 { if len(hostPort) != 2 {
return Endpoint{}, fmt.Errorf("invalid address string: %s", s) return e
} }
host := hostPort[0] host := hostPort[0]
port := hostPort[1] port := hostPort[1]
@ -36,7 +33,7 @@ func EndpointFromString(s string) (Endpoint, error) {
p, _ := strconv.Atoi(port) p, _ := strconv.Atoi(port)
return Endpoint{buf, uint16(p)}, nil return Endpoint{buf, uint16(p)}
} }
// Network implements the net.Addr interface. // Network implements the net.Addr interface.

View file

@ -34,7 +34,7 @@ func Uint256DecodeBytes(b []byte) (u Uint256, err error) {
return u, nil return u, nil
} }
// ToSlice returns a byte slice representation of u. // Bytes returns a byte slice representation of u.
func (u Uint256) Bytes() []byte { func (u Uint256) Bytes() []byte {
b := make([]byte, uint256Size) b := make([]byte, uint256Size)
for i := 0; i < uint256Size; i++ { for i := 0; i < uint256Size; i++ {

View file

@ -8,7 +8,7 @@ import (
) )
const ( const (
// The WIF network version used to decode and encode WIF keys. // WIFVersion is the version used to decode and encode WIF keys.
WIFVersion = 0x80 WIFVersion = 0x80
) )