mirror of
https://github.com/nspcc-dev/neo-go.git
synced 2024-12-23 03:41:34 +00:00
Refactor of the Go node (#44)
* added headersOp for safely processing headers * Better handling of protocol messages. * housekeeping + cleanup tests * Added more blockchain logic + unit tests * fixed unreachable error. * added structured logging for all (node) components. * added relay flag + bumped version
This commit is contained in:
parent
b2a5e34aac
commit
4023661cf1
43 changed files with 1497 additions and 1265 deletions
26
Gopkg.lock
generated
26
Gopkg.lock
generated
|
@ -19,12 +19,36 @@
|
|||
revision = "346938d642f2ec3594ed81d874461961cd0faa76"
|
||||
version = "v1.1.0"
|
||||
|
||||
[[projects]]
|
||||
name = "github.com/go-kit/kit"
|
||||
packages = ["log"]
|
||||
revision = "4dc7be5d2d12881735283bcab7352178e190fc71"
|
||||
version = "v0.6.0"
|
||||
|
||||
[[projects]]
|
||||
name = "github.com/go-logfmt/logfmt"
|
||||
packages = ["."]
|
||||
revision = "390ab7935ee28ec6b286364bba9b4dd6410cb3d5"
|
||||
version = "v0.3.0"
|
||||
|
||||
[[projects]]
|
||||
name = "github.com/go-stack/stack"
|
||||
packages = ["."]
|
||||
revision = "259ab82a6cad3992b4e21ff5cac294ccb06474bc"
|
||||
version = "v1.7.0"
|
||||
|
||||
[[projects]]
|
||||
branch = "master"
|
||||
name = "github.com/golang/snappy"
|
||||
packages = ["."]
|
||||
revision = "553a641470496b2327abcac10b36396bd98e45c9"
|
||||
|
||||
[[projects]]
|
||||
branch = "master"
|
||||
name = "github.com/kr/logfmt"
|
||||
packages = ["."]
|
||||
revision = "b84e30acd515aadc4b783ad4ff83aff3299bdfe0"
|
||||
|
||||
[[projects]]
|
||||
name = "github.com/pmezard/go-difflib"
|
||||
packages = ["difflib"]
|
||||
|
@ -98,6 +122,6 @@
|
|||
[solve-meta]
|
||||
analyzer-name = "dep"
|
||||
analyzer-version = 1
|
||||
inputs-digest = "069a738aa1487766b26f9efb8103d2ce0526d43c83049cb5b792f0edf91568de"
|
||||
inputs-digest = "53597073e919ad7bf52895a19f8b8526d12d666862fb1d36b4a9756e0499da5a"
|
||||
solver-name = "gps-cdcl"
|
||||
solver-version = 1
|
||||
|
|
|
@ -51,3 +51,7 @@
|
|||
[[constraint]]
|
||||
name = "github.com/stretchr/testify"
|
||||
version = "1.2.1"
|
||||
|
||||
[[constraint]]
|
||||
name = "github.com/go-kit/kit"
|
||||
version = "0.6.0"
|
||||
|
|
2
Makefile
2
Makefile
|
@ -19,7 +19,7 @@ push-tag:
|
|||
git push origin ${BRANCH} --tags
|
||||
|
||||
run: build
|
||||
./bin/neo-go node -seed ${SEEDS} -tcp ${PORT}
|
||||
./bin/neo-go node -seed ${SEEDS} -tcp ${PORT} --relay true
|
||||
|
||||
test:
|
||||
@go test ./... -cover
|
||||
|
|
2
VERSION
2
VERSION
|
@ -1 +1 @@
|
|||
0.25.0
|
||||
0.26.0
|
||||
|
|
|
@ -16,6 +16,7 @@ func NewCommand() cli.Command {
|
|||
Flags: []cli.Flag{
|
||||
cli.IntFlag{Name: "tcp"},
|
||||
cli.IntFlag{Name: "rpc"},
|
||||
cli.BoolFlag{Name: "relay, r"},
|
||||
cli.StringFlag{Name: "seed"},
|
||||
cli.BoolFlag{Name: "privnet, p"},
|
||||
cli.BoolFlag{Name: "mainnet, m"},
|
||||
|
@ -25,12 +26,6 @@ func NewCommand() cli.Command {
|
|||
}
|
||||
|
||||
func startServer(ctx *cli.Context) error {
|
||||
opts := network.StartOpts{
|
||||
Seeds: parseSeeds(ctx.String("seed")),
|
||||
TCP: ctx.Int("tcp"),
|
||||
RPC: ctx.Int("rpc"),
|
||||
}
|
||||
|
||||
net := network.ModePrivNet
|
||||
if ctx.Bool("testnet") {
|
||||
net = network.ModeTestNet
|
||||
|
@ -39,8 +34,16 @@ func startServer(ctx *cli.Context) error {
|
|||
net = network.ModeMainNet
|
||||
}
|
||||
|
||||
s := network.NewServer(net)
|
||||
s.Start(opts)
|
||||
cfg := network.Config{
|
||||
UserAgent: "/NEO-GO:0.26.0/",
|
||||
ListenTCP: uint16(ctx.Int("tcp")),
|
||||
Seeds: parseSeeds(ctx.String("seed")),
|
||||
Net: net,
|
||||
Relay: ctx.Bool("relay"),
|
||||
}
|
||||
|
||||
s := network.NewServer(cfg)
|
||||
s.Start()
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
15
main.go
Normal file
15
main.go
Normal file
|
@ -0,0 +1,15 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
log "github.com/go-kit/kit/log"
|
||||
)
|
||||
|
||||
func main() {
|
||||
logger := log.NewLogfmtLogger(log.NewSyncWriter(os.Stderr))
|
||||
logger.Log("hello", true)
|
||||
|
||||
logger = log.With(logger, "module", "node")
|
||||
logger.Log("foo", true)
|
||||
}
|
|
@ -33,6 +33,9 @@ type BlockBase struct {
|
|||
_ uint8 // padding
|
||||
// Script used to validate the block
|
||||
Script *transaction.Witness
|
||||
|
||||
// hash of this block, created when binary encoded.
|
||||
hash util.Uint256
|
||||
}
|
||||
|
||||
// DecodeBinary implements the payload interface.
|
||||
|
@ -68,19 +71,35 @@ func (b *BlockBase) DecodeBinary(r io.Reader) error {
|
|||
}
|
||||
|
||||
b.Script = &transaction.Witness{}
|
||||
return b.Script.DecodeBinary(r)
|
||||
if err := b.Script.DecodeBinary(r); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Make the hash of the block here so we dont need to do this
|
||||
// again.
|
||||
hash, err := b.createHash()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
b.hash = hash
|
||||
return nil
|
||||
}
|
||||
|
||||
// Hash returns the hash of the block.
|
||||
// Hash return the hash of the block.
|
||||
func (b *BlockBase) Hash() util.Uint256 {
|
||||
return b.hash
|
||||
}
|
||||
|
||||
// createHash creates the hash of the block.
|
||||
// When calculating the hash value of the block, instead of calculating the entire block,
|
||||
// only first seven fields in the block head will be calculated, which are
|
||||
// version, PrevBlock, MerkleRoot, timestamp, and height, the nonce, NextMiner.
|
||||
// Since MerkleRoot already contains the hash value of all transactions,
|
||||
// the modification of transaction will influence the hash value of the block.
|
||||
func (b *BlockBase) Hash() (hash util.Uint256, err error) {
|
||||
func (b *BlockBase) createHash() (hash util.Uint256, err error) {
|
||||
buf := new(bytes.Buffer)
|
||||
if err = b.encodeHashableFields(buf); err != nil {
|
||||
return
|
||||
return hash, err
|
||||
}
|
||||
|
||||
// Double hash the encoded fields.
|
||||
|
@ -92,15 +111,25 @@ func (b *BlockBase) Hash() (hash util.Uint256, err error) {
|
|||
// encodeHashableFields will only encode the fields used for hashing.
|
||||
// see Hash() for more information about the fields.
|
||||
func (b *BlockBase) encodeHashableFields(w io.Writer) error {
|
||||
err := binary.Write(w, binary.LittleEndian, &b.Version)
|
||||
err = binary.Write(w, binary.LittleEndian, &b.PrevHash)
|
||||
err = binary.Write(w, binary.LittleEndian, &b.MerkleRoot)
|
||||
err = binary.Write(w, binary.LittleEndian, &b.Timestamp)
|
||||
err = binary.Write(w, binary.LittleEndian, &b.Index)
|
||||
err = binary.Write(w, binary.LittleEndian, &b.ConsensusData)
|
||||
err = binary.Write(w, binary.LittleEndian, &b.NextConsensus)
|
||||
|
||||
return err
|
||||
if err := binary.Write(w, binary.LittleEndian, &b.Version); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := binary.Write(w, binary.LittleEndian, &b.PrevHash); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := binary.Write(w, binary.LittleEndian, &b.MerkleRoot); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := binary.Write(w, binary.LittleEndian, &b.Timestamp); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := binary.Write(w, binary.LittleEndian, &b.Index); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := binary.Write(w, binary.LittleEndian, &b.ConsensusData); err != nil {
|
||||
return err
|
||||
}
|
||||
return binary.Write(w, binary.LittleEndian, &b.NextConsensus)
|
||||
}
|
||||
|
||||
// EncodeBinary implements the Payload interface
|
||||
|
@ -108,21 +137,16 @@ func (b *BlockBase) EncodeBinary(w io.Writer) error {
|
|||
if err := b.encodeHashableFields(w); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// padding
|
||||
if err := binary.Write(w, binary.LittleEndian, uint8(1)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// script
|
||||
return b.Script.EncodeBinary(w)
|
||||
}
|
||||
|
||||
// Header holds the head info of a block
|
||||
type Header struct {
|
||||
BlockBase
|
||||
// fixed to 0
|
||||
_ uint8 // padding
|
||||
_ uint8 // padding fixed to 0
|
||||
}
|
||||
|
||||
// Verify the integrity of the header
|
||||
|
@ -150,15 +174,12 @@ func (h *Header) EncodeBinary(w io.Writer) error {
|
|||
if err := h.BlockBase.EncodeBinary(w); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// padding
|
||||
return binary.Write(w, binary.LittleEndian, uint8(0))
|
||||
}
|
||||
|
||||
// Block represents one block in the chain.
|
||||
type Block struct {
|
||||
BlockBase
|
||||
// transaction list
|
||||
Transactions []*transaction.Transaction
|
||||
}
|
||||
|
||||
|
@ -205,11 +226,10 @@ func (b *Block) DecodeBinary(r io.Reader) error {
|
|||
lentx := util.ReadVarUint(r)
|
||||
b.Transactions = make([]*transaction.Transaction, lentx)
|
||||
for i := 0; i < int(lentx); i++ {
|
||||
tx := &transaction.Transaction{}
|
||||
if err := tx.DecodeBinary(r); err != nil {
|
||||
b.Transactions[i] = &transaction.Transaction{}
|
||||
if err := b.Transactions[i].DecodeBinary(r); err != nil {
|
||||
return err
|
||||
}
|
||||
b.Transactions[i] = tx
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
|
||||
"github.com/CityOfZion/neo-go/pkg/core/transaction"
|
||||
"github.com/CityOfZion/neo-go/pkg/util"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestDecodeBlock(t *testing.T) {
|
||||
|
@ -29,63 +30,24 @@ func TestDecodeBlock(t *testing.T) {
|
|||
if err := block.DecodeBinary(bytes.NewReader(rawBlockBytes)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if block.Index != uint32(rawBlockIndex) {
|
||||
t.Fatalf("expected the index to the block to be %d got %d", rawBlockIndex, block.Index)
|
||||
}
|
||||
if block.Timestamp != uint32(rawBlockTimestamp) {
|
||||
t.Fatalf("expected timestamp to be %d got %d", rawBlockTimestamp, block.Timestamp)
|
||||
}
|
||||
if block.ConsensusData != uint64(rawBlockConsensusData) {
|
||||
t.Fatalf("expected consensus data to be %d got %d", rawBlockConsensusData, block.ConsensusData)
|
||||
}
|
||||
if block.PrevHash.String() != rawBlockPrevHash {
|
||||
t.Fatalf("expected prev block hash to be %s got %s", rawBlockPrevHash, block.PrevHash)
|
||||
}
|
||||
hash, err := block.Hash()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if hash.String() != rawBlockHash {
|
||||
t.Fatalf("expected hash of the block to be %s got %s", rawBlockHash, hash)
|
||||
}
|
||||
}
|
||||
|
||||
func newBlockBase() BlockBase {
|
||||
return BlockBase{
|
||||
Version: 0,
|
||||
PrevHash: sha256.Sum256([]byte("a")),
|
||||
MerkleRoot: sha256.Sum256([]byte("b")),
|
||||
Timestamp: 999,
|
||||
Index: 1,
|
||||
ConsensusData: 1111,
|
||||
NextConsensus: util.Uint160{},
|
||||
Script: &transaction.Witness{
|
||||
VerificationScript: []byte{0x0},
|
||||
InvocationScript: []byte{0x1},
|
||||
},
|
||||
}
|
||||
assert.Equal(t, uint32(rawBlockIndex), block.Index)
|
||||
assert.Equal(t, uint32(rawBlockTimestamp), block.Timestamp)
|
||||
assert.Equal(t, uint64(rawBlockConsensusData), block.ConsensusData)
|
||||
assert.Equal(t, rawBlockPrevHash, block.PrevHash.String())
|
||||
assert.Equal(t, rawBlockHash, block.Hash().String())
|
||||
}
|
||||
|
||||
func TestHashBlockEqualsHashHeader(t *testing.T) {
|
||||
base := newBlockBase()
|
||||
b := &Block{BlockBase: base}
|
||||
head := &Header{BlockBase: base}
|
||||
|
||||
bhash, _ := b.Hash()
|
||||
headhash, _ := head.Hash()
|
||||
if bhash != headhash {
|
||||
t.Fatalf("expected both hashes to be equal %s and %s", bhash, headhash)
|
||||
}
|
||||
block := newBlock(0)
|
||||
assert.Equal(t, block.Hash(), block.Header().Hash())
|
||||
}
|
||||
|
||||
func TestBlockVerify(t *testing.T) {
|
||||
block := &Block{
|
||||
BlockBase: newBlockBase(),
|
||||
Transactions: []*transaction.Transaction{
|
||||
{Type: transaction.MinerType},
|
||||
{Type: transaction.IssueType},
|
||||
},
|
||||
}
|
||||
block := newBlock(
|
||||
0,
|
||||
newTX(transaction.MinerType),
|
||||
newTX(transaction.IssueType),
|
||||
)
|
||||
|
||||
if !block.Verify(false) {
|
||||
t.Fatal("block should be verified")
|
||||
|
@ -109,3 +71,19 @@ func TestBlockVerify(t *testing.T) {
|
|||
t.Fatal("block should not by verified")
|
||||
}
|
||||
}
|
||||
|
||||
func newBlockBase() BlockBase {
|
||||
return BlockBase{
|
||||
Version: 0,
|
||||
PrevHash: sha256.Sum256([]byte("a")),
|
||||
MerkleRoot: sha256.Sum256([]byte("b")),
|
||||
Timestamp: 999,
|
||||
Index: 1,
|
||||
ConsensusData: 1111,
|
||||
NextConsensus: util.Uint160{},
|
||||
Script: &transaction.Witness{
|
||||
VerificationScript: []byte{0x0},
|
||||
InvocationScript: []byte{0x1},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,17 +3,19 @@ package core
|
|||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"log"
|
||||
"sync"
|
||||
"fmt"
|
||||
"os"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/CityOfZion/neo-go/pkg/util"
|
||||
log "github.com/go-kit/kit/log"
|
||||
)
|
||||
|
||||
// tuning parameters
|
||||
const (
|
||||
secondsPerBlock = 15
|
||||
writeHdrBatchCnt = 2000
|
||||
headerBatchCount = 2000
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -22,188 +24,217 @@ var (
|
|||
|
||||
// Blockchain holds the chain.
|
||||
type Blockchain struct {
|
||||
logger *log.Logger
|
||||
logger log.Logger
|
||||
|
||||
// Any object that satisfies the BlockchainStorer interface.
|
||||
Store
|
||||
|
||||
// current index of the heighest block
|
||||
currentBlockHeight uint32
|
||||
// Current index/height of the highest block.
|
||||
// Read access should always be called by BlockHeight().
|
||||
// Writes access should only happen in persist().
|
||||
blockHeight uint32
|
||||
|
||||
// number of headers stored
|
||||
// Number of headers stored.
|
||||
storedHeaderCount uint32
|
||||
|
||||
mtx sync.RWMutex
|
||||
blockCache *Cache
|
||||
|
||||
// index of headers hashes
|
||||
headerIndex []util.Uint256
|
||||
startHash util.Uint256
|
||||
|
||||
// Only for operating on the headerList.
|
||||
headersOp chan headersOpFunc
|
||||
headersOpDone chan struct{}
|
||||
}
|
||||
|
||||
// NewBlockchain returns a pointer to a Blockchain.
|
||||
func NewBlockchain(s Store, l *log.Logger, startHash util.Uint256) *Blockchain {
|
||||
type headersOpFunc func(headerList *HeaderHashList)
|
||||
|
||||
// NewBlockchain creates a new Blockchain object.
|
||||
func NewBlockchain(s Store, startHash util.Uint256) *Blockchain {
|
||||
logger := log.NewLogfmtLogger(os.Stderr)
|
||||
logger = log.With(logger, "component", "blockchain")
|
||||
|
||||
bc := &Blockchain{
|
||||
logger: l,
|
||||
Store: s,
|
||||
logger: logger,
|
||||
Store: s,
|
||||
headersOp: make(chan headersOpFunc),
|
||||
headersOpDone: make(chan struct{}),
|
||||
startHash: startHash,
|
||||
blockCache: NewCache(),
|
||||
}
|
||||
|
||||
// Starthash is 0, so we will create the genesis block.
|
||||
if startHash.Equals(util.Uint256{}) {
|
||||
bc.logger.Fatal("genesis block not yet implemented")
|
||||
}
|
||||
|
||||
bc.headerIndex = []util.Uint256{startHash}
|
||||
go bc.run()
|
||||
bc.init()
|
||||
|
||||
return bc
|
||||
}
|
||||
|
||||
// genesisBlock creates the genesis block for the chain.
|
||||
// hash of the genesis block:
|
||||
// d42561e3d30e15be6400b6df2f328e02d2bf6354c41dce433bc57687c82144bf
|
||||
func (bc *Blockchain) genesisBlock() *Block {
|
||||
timestamp := uint32(time.Date(2016, 7, 15, 15, 8, 21, 0, time.UTC).Unix())
|
||||
func (bc *Blockchain) init() {
|
||||
// for the initial header, for now
|
||||
bc.storedHeaderCount = 1
|
||||
}
|
||||
|
||||
// TODO: for testing I will hardcode the merkleroot.
|
||||
// This let's me focus on the bringing all the puzzle pieces
|
||||
// togheter much faster.
|
||||
// For more information about the genesis block:
|
||||
// https://neotracker.io/block/height/0
|
||||
mr, _ := util.Uint256DecodeString("803ff4abe3ea6533bcc0be574efa02f83ae8fdc651c879056b0d9be336c01bf4")
|
||||
|
||||
return &Block{
|
||||
BlockBase: BlockBase{
|
||||
Version: 0,
|
||||
PrevHash: util.Uint256{},
|
||||
MerkleRoot: mr,
|
||||
Timestamp: timestamp,
|
||||
Index: 0,
|
||||
ConsensusData: 2083236893, // nioctib ^^
|
||||
NextConsensus: util.Uint160{}, // todo
|
||||
},
|
||||
func (bc *Blockchain) run() {
|
||||
headerList := NewHeaderHashList(bc.startHash)
|
||||
for {
|
||||
select {
|
||||
case op := <-bc.headersOp:
|
||||
op(headerList)
|
||||
bc.headersOpDone <- struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// AddBlock (to be continued after headers is finished..)
|
||||
func (bc *Blockchain) AddBlock(block *Block) error {
|
||||
// TODO: caching
|
||||
headerLen := len(bc.headerIndex)
|
||||
if !bc.blockCache.Has(block.Hash()) {
|
||||
bc.blockCache.Add(block.Hash(), block)
|
||||
}
|
||||
|
||||
headerLen := int(bc.HeaderHeight() + 1)
|
||||
if int(block.Index-1) >= headerLen {
|
||||
return nil
|
||||
}
|
||||
|
||||
if int(block.Index) == headerLen {
|
||||
// todo: if (VerifyBlocks && !block.Verify()) return false;
|
||||
}
|
||||
|
||||
if int(block.Index) < headerLen {
|
||||
return nil
|
||||
}
|
||||
|
||||
return nil
|
||||
return bc.AddHeaders(block.Header())
|
||||
}
|
||||
|
||||
func (bc *Blockchain) addHeader(header *Header) error {
|
||||
return bc.AddHeaders(header)
|
||||
func (bc *Blockchain) AddHeaders(headers ...*Header) (err error) {
|
||||
var (
|
||||
start = time.Now()
|
||||
batch = Batch{}
|
||||
)
|
||||
|
||||
bc.headersOp <- func(headerList *HeaderHashList) {
|
||||
for _, h := range headers {
|
||||
if int(h.Index-1) >= headerList.Len() {
|
||||
err = fmt.Errorf(
|
||||
"height of block higher then current header height %d > %d\n",
|
||||
h.Index, headerList.Len(),
|
||||
)
|
||||
return
|
||||
}
|
||||
if int(h.Index) < headerList.Len() {
|
||||
continue
|
||||
}
|
||||
if !h.Verify() {
|
||||
err = fmt.Errorf("header %v is invalid", h)
|
||||
return
|
||||
}
|
||||
if err = bc.processHeader(h, batch, headerList); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Implement caching strategy.
|
||||
if len(batch) > 0 {
|
||||
if err = bc.writeBatch(batch); err != nil {
|
||||
return
|
||||
}
|
||||
bc.logger.Log(
|
||||
"msg", "done processing headers",
|
||||
"index", headerList.Len()-1,
|
||||
"took", time.Since(start).Seconds(),
|
||||
)
|
||||
}
|
||||
}
|
||||
<-bc.headersOpDone
|
||||
return err
|
||||
}
|
||||
|
||||
// AddHeaders processes the given headers.
|
||||
func (bc *Blockchain) AddHeaders(headers ...*Header) error {
|
||||
start := time.Now()
|
||||
|
||||
bc.mtx.Lock()
|
||||
defer bc.mtx.Unlock()
|
||||
|
||||
batch := Batch{}
|
||||
for _, h := range headers {
|
||||
if int(h.Index-1) >= len(bc.headerIndex) {
|
||||
bc.logger.Printf("height of block higher then header index %d %d\n",
|
||||
h.Index, len(bc.headerIndex))
|
||||
break
|
||||
}
|
||||
if int(h.Index) < len(bc.headerIndex) {
|
||||
continue
|
||||
}
|
||||
if !h.Verify() {
|
||||
bc.logger.Printf("header %v is invalid", h)
|
||||
break
|
||||
}
|
||||
if err := bc.processHeader(h, batch); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Implement caching strategy.
|
||||
if len(batch) > 0 {
|
||||
// Write all batches.
|
||||
if err := bc.writeBatch(batch); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
bc.logger.Printf("done processing headers up to index %d took %f Seconds",
|
||||
bc.HeaderHeight(), time.Since(start).Seconds())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// processHeader processes 1 header.
|
||||
func (bc *Blockchain) processHeader(h *Header, batch Batch) error {
|
||||
hash, err := h.Hash()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
bc.headerIndex = append(bc.headerIndex, hash)
|
||||
|
||||
for int(h.Index)-writeHdrBatchCnt >= int(bc.storedHeaderCount) {
|
||||
// hdrsToWrite = bc.headerIndex[bc.storedHeaderCount : bc.storedHeaderCount+writeHdrBatchCnt]
|
||||
|
||||
// NOTE: from original #c to be implemented:
|
||||
//
|
||||
// w.Write(header_index.Skip((int)stored_header_count).Take(2000).ToArray());
|
||||
// w.Flush();
|
||||
// batch.Put(SliceBuilder.Begin(DataEntryPrefix.IX_HeaderHashList).Add(stored_header_count), ms.ToArray());
|
||||
|
||||
bc.storedHeaderCount += writeHdrBatchCnt
|
||||
}
|
||||
// processHeader processes the given header. Note that this is only thread safe
|
||||
// if executed in headers operation.
|
||||
func (bc *Blockchain) processHeader(h *Header, batch Batch, headerList *HeaderHashList) error {
|
||||
headerList.Add(h.Hash())
|
||||
|
||||
buf := new(bytes.Buffer)
|
||||
for int(h.Index)-headerBatchCount >= int(bc.storedHeaderCount) {
|
||||
if err := headerList.Write(buf, int(bc.storedHeaderCount), headerBatchCount); err != nil {
|
||||
return err
|
||||
}
|
||||
key := makeEntryPrefixInt(preIXHeaderHashList, int(bc.storedHeaderCount))
|
||||
batch[&key] = buf.Bytes()
|
||||
bc.storedHeaderCount += headerBatchCount
|
||||
buf.Reset()
|
||||
}
|
||||
|
||||
buf.Reset()
|
||||
if err := h.EncodeBinary(buf); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
preBlock := preDataBlock.add(hash.BytesReverse())
|
||||
batch[&preBlock] = buf.Bytes()
|
||||
preHeader := preSYSCurrentHeader.toSlice()
|
||||
batch[&preHeader] = hashAndIndexToBytes(hash, h.Index)
|
||||
key := makeEntryPrefix(preDataBlock, h.Hash().BytesReverse())
|
||||
batch[&key] = buf.Bytes()
|
||||
key = preSYSCurrentHeader.bytes()
|
||||
batch[&key] = hashAndIndexToBytes(h.Hash(), h.Index)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CurrentBlockHash return the lastest hash in the header index.
|
||||
func (bc *Blockchain) CurrentBlockHash() (hash util.Uint256) {
|
||||
if len(bc.headerIndex) == 0 {
|
||||
return
|
||||
}
|
||||
if len(bc.headerIndex) < int(bc.currentBlockHeight) {
|
||||
return
|
||||
}
|
||||
func (bc *Blockchain) persistBlock(block *Block) error {
|
||||
bc.blockHeight = block.Index
|
||||
return nil
|
||||
}
|
||||
|
||||
return bc.headerIndex[bc.currentBlockHeight]
|
||||
func (bc *Blockchain) persist() (err error) {
|
||||
var (
|
||||
persisted = 0
|
||||
lenCache = bc.blockCache.Len()
|
||||
)
|
||||
|
||||
for lenCache > persisted {
|
||||
if bc.HeaderHeight()+1 <= bc.BlockHeight() {
|
||||
break
|
||||
}
|
||||
bc.headersOp <- func(headerList *HeaderHashList) {
|
||||
hash := headerList.Get(int(bc.BlockHeight() + 1))
|
||||
if block, ok := bc.blockCache.GetBlock(hash); ok {
|
||||
if err = bc.persistBlock(block); err != nil {
|
||||
return
|
||||
}
|
||||
bc.blockCache.Delete(hash)
|
||||
persisted++
|
||||
} else {
|
||||
bc.logger.Log(
|
||||
"msg", "block not found in cache",
|
||||
"hash", block.Hash(),
|
||||
)
|
||||
}
|
||||
}
|
||||
<-bc.headersOpDone
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// CurrentBlockHash returns the heighest processed block hash.
|
||||
func (bc *Blockchain) CurrentBlockHash() (hash util.Uint256) {
|
||||
bc.headersOp <- func(headerList *HeaderHashList) {
|
||||
hash = headerList.Get(int(bc.BlockHeight()))
|
||||
}
|
||||
<-bc.headersOpDone
|
||||
return
|
||||
}
|
||||
|
||||
// CurrentHeaderHash returns the hash of the latest known header.
|
||||
func (bc *Blockchain) CurrentHeaderHash() (hash util.Uint256) {
|
||||
return bc.headerIndex[len(bc.headerIndex)-1]
|
||||
bc.headersOp <- func(headerList *HeaderHashList) {
|
||||
hash = headerList.Last()
|
||||
}
|
||||
<-bc.headersOpDone
|
||||
return
|
||||
}
|
||||
|
||||
// BlockHeight return the height/index of the latest block this node has.
|
||||
// BlockHeight returns the height/index of the highest block.
|
||||
func (bc *Blockchain) BlockHeight() uint32 {
|
||||
return bc.currentBlockHeight
|
||||
return atomic.LoadUint32(&bc.blockHeight)
|
||||
}
|
||||
|
||||
// HeaderHeight returns the current index of the headers.
|
||||
func (bc *Blockchain) HeaderHeight() uint32 {
|
||||
return uint32(len(bc.headerIndex)) - 1
|
||||
// HeaderHeight returns the index/height of the highest header.
|
||||
func (bc *Blockchain) HeaderHeight() (n uint32) {
|
||||
bc.headersOp <- func(headerList *HeaderHashList) {
|
||||
n = uint32(headerList.Len() - 1)
|
||||
}
|
||||
<-bc.headersOpDone
|
||||
return
|
||||
}
|
||||
|
||||
func hashAndIndexToBytes(h util.Uint256, index uint32) []byte {
|
||||
|
|
|
@ -1,48 +1,70 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"log"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/CityOfZion/neo-go/pkg/core/transaction"
|
||||
"github.com/CityOfZion/neo-go/pkg/util"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestNewBlockchain(t *testing.T) {
|
||||
startHash, _ := util.Uint256DecodeString("996e37358dc369912041f966f8c5d8d3a8255ba5dcbd3447f8a82b55db869099")
|
||||
bc := NewBlockchain(nil, nil, startHash)
|
||||
bc := NewBlockchain(nil, startHash)
|
||||
|
||||
want := uint32(0)
|
||||
if have := bc.BlockHeight(); want != have {
|
||||
t.Fatalf("expected %d got %d", want, have)
|
||||
}
|
||||
if have := bc.HeaderHeight(); want != have {
|
||||
t.Fatalf("expected %d got %d", want, have)
|
||||
}
|
||||
if have := bc.storedHeaderCount; want != have {
|
||||
t.Fatalf("expected %d got %d", want, have)
|
||||
}
|
||||
if !bc.CurrentBlockHash().Equals(startHash) {
|
||||
t.Fatalf("expected current block hash to be %d got %s", startHash, bc.CurrentBlockHash())
|
||||
}
|
||||
assert.Equal(t, uint32(0), bc.BlockHeight())
|
||||
assert.Equal(t, uint32(0), bc.HeaderHeight())
|
||||
assert.Equal(t, uint32(1), bc.storedHeaderCount)
|
||||
assert.Equal(t, startHash, bc.startHash)
|
||||
}
|
||||
|
||||
func TestAddHeaders(t *testing.T) {
|
||||
startHash, _ := util.Uint256DecodeString("996e37358dc369912041f966f8c5d8d3a8255ba5dcbd3447f8a82b55db869099")
|
||||
bc := NewBlockchain(NewMemoryStore(), log.New(os.Stdout, "", 0), startHash)
|
||||
|
||||
h1 := &Header{BlockBase: BlockBase{Version: 0, Index: 1, Script: &transaction.Witness{}}}
|
||||
h2 := &Header{BlockBase: BlockBase{Version: 0, Index: 2, Script: &transaction.Witness{}}}
|
||||
h3 := &Header{BlockBase: BlockBase{Version: 0, Index: 3, Script: &transaction.Witness{}}}
|
||||
bc := newTestBC()
|
||||
h1 := newBlock(1).Header()
|
||||
h2 := newBlock(2).Header()
|
||||
h3 := newBlock(3).Header()
|
||||
|
||||
if err := bc.AddHeaders(h1, h2, h3); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if want, have := h3.Index, bc.HeaderHeight(); want != have {
|
||||
t.Fatalf("expected header height of %d got %d", want, have)
|
||||
}
|
||||
if want, have := uint32(0), bc.storedHeaderCount; want != have {
|
||||
t.Fatalf("expected stored header count to be %d got %d", want, have)
|
||||
}
|
||||
|
||||
assert.Equal(t, 0, bc.blockCache.Len())
|
||||
assert.Equal(t, h3.Index, bc.HeaderHeight())
|
||||
assert.Equal(t, uint32(1), bc.storedHeaderCount)
|
||||
assert.Equal(t, uint32(0), bc.BlockHeight())
|
||||
assert.Equal(t, h3.Hash(), bc.CurrentHeaderHash())
|
||||
}
|
||||
|
||||
func TestAddBlock(t *testing.T) {
|
||||
bc := newTestBC()
|
||||
blocks := []*Block{
|
||||
newBlock(1),
|
||||
newBlock(2),
|
||||
newBlock(3),
|
||||
}
|
||||
|
||||
for i := 0; i < len(blocks); i++ {
|
||||
if err := bc.AddBlock(blocks[i]); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
lastBlock := blocks[len(blocks)-1]
|
||||
assert.Equal(t, 3, bc.blockCache.Len())
|
||||
assert.Equal(t, lastBlock.Index, bc.HeaderHeight())
|
||||
assert.Equal(t, lastBlock.Hash(), bc.CurrentHeaderHash())
|
||||
assert.Equal(t, uint32(1), bc.storedHeaderCount)
|
||||
|
||||
if err := bc.persist(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
assert.Equal(t, lastBlock.Index, bc.BlockHeight())
|
||||
assert.Equal(t, lastBlock.Hash(), bc.CurrentHeaderHash())
|
||||
assert.Equal(t, 0, bc.blockCache.Len())
|
||||
}
|
||||
|
||||
func newTestBC() *Blockchain {
|
||||
startHash, _ := util.Uint256DecodeString("a")
|
||||
bc := NewBlockchain(NewMemoryStore(), startHash)
|
||||
return bc
|
||||
}
|
||||
|
|
73
pkg/core/cache.go
Normal file
73
pkg/core/cache.go
Normal file
|
@ -0,0 +1,73 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/CityOfZion/neo-go/pkg/util"
|
||||
)
|
||||
|
||||
// Cache is data structure with fixed type key of Uint256, but has a
|
||||
// generic value. Used for block and header cash types.
|
||||
type Cache struct {
|
||||
lock sync.RWMutex
|
||||
m map[util.Uint256]interface{}
|
||||
}
|
||||
|
||||
// NewCache returns a ready to use Cache object.
|
||||
func NewCache() *Cache {
|
||||
return &Cache{
|
||||
m: make(map[util.Uint256]interface{}),
|
||||
}
|
||||
}
|
||||
|
||||
// GetBlock will return a Block type from the cache.
|
||||
func (c *Cache) GetBlock(h util.Uint256) (block *Block, ok bool) {
|
||||
c.lock.RLock()
|
||||
defer c.lock.RUnlock()
|
||||
return c.getBlock(h)
|
||||
}
|
||||
|
||||
func (c *Cache) getBlock(h util.Uint256) (block *Block, ok bool) {
|
||||
if v, b := c.m[h]; b {
|
||||
block, ok = v.(*Block)
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Add adds the given hash along with its value to the cache.
|
||||
func (c *Cache) Add(h util.Uint256, v interface{}) {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
c.add(h, v)
|
||||
}
|
||||
|
||||
func (c *Cache) add(h util.Uint256, v interface{}) {
|
||||
c.m[h] = v
|
||||
}
|
||||
|
||||
func (c *Cache) has(h util.Uint256) bool {
|
||||
_, ok := c.m[h]
|
||||
return ok
|
||||
}
|
||||
|
||||
// Hash returns whether the cach contains the given hash.
|
||||
func (c *Cache) Has(h util.Uint256) bool {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
return c.has(h)
|
||||
}
|
||||
|
||||
// Len return the number of items present in the cache.
|
||||
func (c *Cache) Len() int {
|
||||
c.lock.RLock()
|
||||
defer c.lock.RUnlock()
|
||||
return len(c.m)
|
||||
}
|
||||
|
||||
// Delete removes the item out of the cache.
|
||||
func (c *Cache) Delete(h util.Uint256) {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
delete(c.m, h)
|
||||
}
|
66
pkg/core/header_hash_list.go
Normal file
66
pkg/core/header_hash_list.go
Normal file
|
@ -0,0 +1,66 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"io"
|
||||
|
||||
"github.com/CityOfZion/neo-go/pkg/util"
|
||||
)
|
||||
|
||||
// A HeaderHashList represents a list of header hashes.
|
||||
type HeaderHashList struct {
|
||||
hashes []util.Uint256
|
||||
}
|
||||
|
||||
// NewHeaderHashList return a new pointer to a HeaderHashList.
|
||||
func NewHeaderHashList(hashes ...util.Uint256) *HeaderHashList {
|
||||
return &HeaderHashList{
|
||||
hashes: hashes,
|
||||
}
|
||||
}
|
||||
|
||||
// Add appends the given hash to the list of hashes.
|
||||
func (l *HeaderHashList) Add(h util.Uint256) {
|
||||
l.hashes = append(l.hashes, h)
|
||||
}
|
||||
|
||||
// Len return the length of the underlying hashes slice.
|
||||
func (l *HeaderHashList) Len() int {
|
||||
return len(l.hashes)
|
||||
}
|
||||
|
||||
// Get returns the hash by the given index.
|
||||
func (l *HeaderHashList) Get(i int) util.Uint256 {
|
||||
if l.Len() < i {
|
||||
return util.Uint256{}
|
||||
}
|
||||
return l.hashes[i]
|
||||
}
|
||||
|
||||
// Last return the last hash in the HeaderHashList.
|
||||
func (l *HeaderHashList) Last() util.Uint256 {
|
||||
return l.hashes[l.Len()-1]
|
||||
}
|
||||
|
||||
// Slice return a subslice of the underlying hashes.
|
||||
// Subsliced from start to end.
|
||||
// Example:
|
||||
// headers := headerList.Slice(0, 2000)
|
||||
func (l *HeaderHashList) Slice(start, end int) []util.Uint256 {
|
||||
return l.hashes[start:end]
|
||||
}
|
||||
|
||||
// WriteTo will write n underlying hashes to the given io.Writer
|
||||
// starting from start.
|
||||
func (l *HeaderHashList) Write(w io.Writer, start, n int) error {
|
||||
if err := util.WriteVarUint(w, uint64(n)); err != nil {
|
||||
return err
|
||||
}
|
||||
hashes := l.Slice(start, start+n)
|
||||
for _, hash := range hashes {
|
||||
if err := binary.Write(w, binary.LittleEndian, hash); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
40
pkg/core/helper_test.go
Normal file
40
pkg/core/helper_test.go
Normal file
|
@ -0,0 +1,40 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"time"
|
||||
|
||||
"github.com/CityOfZion/neo-go/pkg/core/transaction"
|
||||
"github.com/CityOfZion/neo-go/pkg/util"
|
||||
)
|
||||
|
||||
func newBlock(index uint32, txs ...*transaction.Transaction) *Block {
|
||||
b := &Block{
|
||||
BlockBase: BlockBase{
|
||||
Version: 0,
|
||||
PrevHash: sha256.Sum256([]byte("a")),
|
||||
MerkleRoot: sha256.Sum256([]byte("b")),
|
||||
Timestamp: uint32(time.Now().UTC().Unix()),
|
||||
Index: index,
|
||||
ConsensusData: 1111,
|
||||
NextConsensus: util.Uint160{},
|
||||
Script: &transaction.Witness{
|
||||
VerificationScript: []byte{0x0},
|
||||
InvocationScript: []byte{0x1},
|
||||
},
|
||||
},
|
||||
Transactions: txs,
|
||||
}
|
||||
hash, err := b.createHash()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
b.hash = hash
|
||||
return b
|
||||
}
|
||||
|
||||
func newTX(t transaction.TXType) *transaction.Transaction {
|
||||
return &transaction.Transaction{
|
||||
Type: t,
|
||||
}
|
||||
}
|
|
@ -21,6 +21,5 @@ func (s *LevelDBStore) writeBatch(batch Batch) error {
|
|||
for k, v := range batch {
|
||||
b.Put(*k, v)
|
||||
}
|
||||
|
||||
return s.db.Write(b, nil)
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package core
|
||||
|
||||
// MemoryStore is an in memory implementation of a BlockChainStorer.
|
||||
// MemoryStore is an in memory implementation of a BlockChainStorer
|
||||
// that should only be used for testing.
|
||||
type MemoryStore struct {
|
||||
}
|
||||
|
||||
|
@ -14,5 +15,10 @@ func (m *MemoryStore) write(key, value []byte) error {
|
|||
}
|
||||
|
||||
func (m *MemoryStore) writeBatch(batch Batch) error {
|
||||
for k, v := range batch {
|
||||
if err := m.write(*k, v); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -1,17 +1,13 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
)
|
||||
|
||||
type dataEntry uint8
|
||||
|
||||
func (e dataEntry) add(b []byte) []byte {
|
||||
dest := make([]byte, len(b)+1)
|
||||
dest[0] = byte(e)
|
||||
for i := 1; i < len(b); i++ {
|
||||
dest[i] = b[i]
|
||||
}
|
||||
return dest
|
||||
}
|
||||
|
||||
func (e dataEntry) toSlice() []byte {
|
||||
func (e dataEntry) bytes() []byte {
|
||||
return []byte{byte(e)}
|
||||
}
|
||||
|
||||
|
@ -32,6 +28,21 @@ const (
|
|||
preSYSVersion dataEntry = 0xf0
|
||||
)
|
||||
|
||||
func makeEntryPrefixInt(e dataEntry, n int) []byte {
|
||||
buf := new(bytes.Buffer)
|
||||
binary.Write(buf, binary.LittleEndian, n)
|
||||
return makeEntryPrefix(e, buf.Bytes())
|
||||
}
|
||||
|
||||
func makeEntryPrefix(e dataEntry, b []byte) []byte {
|
||||
dest := make([]byte, len(b)+1)
|
||||
dest[0] = byte(e)
|
||||
for i := 1; i < len(b); i++ {
|
||||
dest[i] = b[i]
|
||||
}
|
||||
return dest
|
||||
}
|
||||
|
||||
// Store is anything that can persist and retrieve the blockchain.
|
||||
type Store interface {
|
||||
write(k, v []byte) error
|
||||
|
@ -39,5 +50,5 @@ type Store interface {
|
|||
}
|
||||
|
||||
// Batch is a data type used to store data for later batch operations
|
||||
// by any Store.
|
||||
// that can be used by any Store interface implementation.
|
||||
type Batch map[*[]byte][]byte
|
||||
|
|
|
@ -10,7 +10,7 @@ import (
|
|||
// Transaction is a process recorded in the NEO blockchain.
|
||||
type Transaction struct {
|
||||
// The type of the transaction.
|
||||
Type TransactionType
|
||||
Type TXType
|
||||
|
||||
// The trading version which is currently 0.
|
||||
Version uint8
|
||||
|
|
|
@ -1,26 +1,26 @@
|
|||
package transaction
|
||||
|
||||
// TransactionType is the type of a transaction.
|
||||
type TransactionType uint8
|
||||
// TXType is the type of a transaction.
|
||||
type TXType uint8
|
||||
|
||||
// All processes in NEO system are recorded in transactions.
|
||||
// There are several types of transactions.
|
||||
const (
|
||||
MinerType TransactionType = 0x00
|
||||
IssueType TransactionType = 0x01
|
||||
ClaimType TransactionType = 0x02
|
||||
EnrollmentType TransactionType = 0x20
|
||||
VotingType TransactionType = 0x24
|
||||
RegisterType TransactionType = 0x40
|
||||
ContractType TransactionType = 0x80
|
||||
StateType TransactionType = 0x90
|
||||
AgencyType TransactionType = 0xb0
|
||||
PublishType TransactionType = 0xd0
|
||||
InvocationType TransactionType = 0xd1
|
||||
MinerType TXType = 0x00
|
||||
IssueType TXType = 0x01
|
||||
ClaimType TXType = 0x02
|
||||
EnrollmentType TXType = 0x20
|
||||
VotingType TXType = 0x24
|
||||
RegisterType TXType = 0x40
|
||||
ContractType TXType = 0x80
|
||||
StateType TXType = 0x90
|
||||
AgencyType TXType = 0xb0
|
||||
PublishType TXType = 0xd0
|
||||
InvocationType TXType = 0xd1
|
||||
)
|
||||
|
||||
// String implements the stringer interface.
|
||||
func (t TransactionType) String() string {
|
||||
func (t TXType) String() string {
|
||||
switch t {
|
||||
case MinerType:
|
||||
return "miner transaction"
|
||||
|
|
12
pkg/core/util.go
Normal file
12
pkg/core/util.go
Normal file
|
@ -0,0 +1,12 @@
|
|||
package core
|
||||
|
||||
import "github.com/CityOfZion/neo-go/pkg/util"
|
||||
|
||||
// Utilities for quick bootstrapping blockchains. Normally we should
|
||||
// create the genisis block. For now (to speed up development) we will add
|
||||
// The hashes manually.
|
||||
|
||||
func GenesisHashPrivNet() util.Uint256 {
|
||||
hash, _ := util.Uint256DecodeString("996e37358dc369912041f966f8c5d8d3a8255ba5dcbd3447f8a82b55db869099")
|
||||
return hash
|
||||
}
|
|
@ -79,6 +79,7 @@ func Base58Encode(bytes []byte) string {
|
|||
return encoded
|
||||
}
|
||||
|
||||
// Base58CheckDecode decodes the given string.
|
||||
func Base58CheckDecode(s string) (b []byte, err error) {
|
||||
b, err = Base58Decode(s)
|
||||
if err != nil {
|
||||
|
|
|
@ -61,26 +61,28 @@ type Message struct {
|
|||
Payload payload.Payload
|
||||
}
|
||||
|
||||
type commandType string
|
||||
// CommandType represents the type of a message command.
|
||||
type CommandType string
|
||||
|
||||
// valid commands used to send between nodes.
|
||||
const (
|
||||
cmdVersion commandType = "version"
|
||||
cmdVerack = "verack"
|
||||
cmdGetAddr = "getaddr"
|
||||
cmdAddr = "addr"
|
||||
cmdGetHeaders = "getheaders"
|
||||
cmdHeaders = "headers"
|
||||
cmdGetBlocks = "getblocks"
|
||||
cmdInv = "inv"
|
||||
cmdGetData = "getdata"
|
||||
cmdBlock = "block"
|
||||
cmdTX = "tx"
|
||||
cmdConsensus = "consensus"
|
||||
cmdUnknown = "unknown"
|
||||
CMDVersion CommandType = "version"
|
||||
CMDVerack CommandType = "verack"
|
||||
CMDGetAddr CommandType = "getaddr"
|
||||
CMDAddr CommandType = "addr"
|
||||
CMDGetHeaders CommandType = "getheaders"
|
||||
CMDHeaders CommandType = "headers"
|
||||
CMDGetBlocks CommandType = "getblocks"
|
||||
CMDInv CommandType = "inv"
|
||||
CMDGetData CommandType = "getdata"
|
||||
CMDBlock CommandType = "block"
|
||||
CMDTX CommandType = "tx"
|
||||
CMDConsensus CommandType = "consensus"
|
||||
CMDUnknown CommandType = "unknown"
|
||||
)
|
||||
|
||||
func newMessage(magic NetMode, cmd commandType, p payload.Payload) *Message {
|
||||
// NewMessage returns a new message with the given payload.
|
||||
func NewMessage(magic NetMode, cmd CommandType, p payload.Payload) *Message {
|
||||
var (
|
||||
size uint32
|
||||
checksum []byte
|
||||
|
@ -106,36 +108,36 @@ func newMessage(magic NetMode, cmd commandType, p payload.Payload) *Message {
|
|||
}
|
||||
}
|
||||
|
||||
// Converts the 12 byte command slice to a commandType.
|
||||
func (m *Message) commandType() commandType {
|
||||
// CommandType converts the 12 byte command slice to a CommandType.
|
||||
func (m *Message) CommandType() CommandType {
|
||||
cmd := cmdByteArrayToString(m.Command)
|
||||
switch cmd {
|
||||
case "version":
|
||||
return cmdVersion
|
||||
return CMDVersion
|
||||
case "verack":
|
||||
return cmdVerack
|
||||
return CMDVerack
|
||||
case "getaddr":
|
||||
return cmdGetAddr
|
||||
return CMDGetAddr
|
||||
case "addr":
|
||||
return cmdAddr
|
||||
return CMDAddr
|
||||
case "getheaders":
|
||||
return cmdGetHeaders
|
||||
return CMDGetHeaders
|
||||
case "headers":
|
||||
return cmdHeaders
|
||||
return CMDHeaders
|
||||
case "getblocks":
|
||||
return cmdGetBlocks
|
||||
return CMDGetBlocks
|
||||
case "inv":
|
||||
return cmdInv
|
||||
return CMDInv
|
||||
case "getdata":
|
||||
return cmdGetData
|
||||
return CMDGetData
|
||||
case "block":
|
||||
return cmdBlock
|
||||
return CMDBlock
|
||||
case "tx":
|
||||
return cmdTX
|
||||
return CMDTX
|
||||
case "consensus":
|
||||
return cmdConsensus
|
||||
return CMDConsensus
|
||||
default:
|
||||
return cmdUnknown
|
||||
return CMDUnknown
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -185,36 +187,35 @@ func (m *Message) decodePayload(r io.Reader) error {
|
|||
return errChecksumMismatch
|
||||
}
|
||||
|
||||
//r = bytes.NewReader(buf)
|
||||
r = buf
|
||||
var p payload.Payload
|
||||
switch m.commandType() {
|
||||
case cmdVersion:
|
||||
switch m.CommandType() {
|
||||
case CMDVersion:
|
||||
p = &payload.Version{}
|
||||
if err := p.DecodeBinary(r); err != nil {
|
||||
return err
|
||||
}
|
||||
case cmdInv:
|
||||
case CMDInv:
|
||||
p = &payload.Inventory{}
|
||||
if err := p.DecodeBinary(r); err != nil {
|
||||
return err
|
||||
}
|
||||
case cmdAddr:
|
||||
case CMDAddr:
|
||||
p = &payload.AddressList{}
|
||||
if err := p.DecodeBinary(r); err != nil {
|
||||
return err
|
||||
}
|
||||
case cmdBlock:
|
||||
case CMDBlock:
|
||||
p = &core.Block{}
|
||||
if err := p.DecodeBinary(r); err != nil {
|
||||
return err
|
||||
}
|
||||
case cmdGetHeaders:
|
||||
case CMDGetHeaders:
|
||||
p = &payload.GetBlocks{}
|
||||
if err := p.DecodeBinary(r); err != nil {
|
||||
return err
|
||||
}
|
||||
case cmdHeaders:
|
||||
case CMDHeaders:
|
||||
p = &payload.Headers{}
|
||||
if err := p.DecodeBinary(r); err != nil {
|
||||
return err
|
||||
|
@ -242,7 +243,7 @@ func (m *Message) encode(w io.Writer) error {
|
|||
|
||||
// convert a command (string) to a byte slice filled with 0 bytes till
|
||||
// size 12.
|
||||
func cmdToByteArray(cmd commandType) [cmdSize]byte {
|
||||
func cmdToByteArray(cmd CommandType) [cmdSize]byte {
|
||||
cmdLen := len(cmd)
|
||||
if cmdLen > cmdSize {
|
||||
panic("exceeded command max length of size 12")
|
||||
|
|
|
@ -2,39 +2,33 @@ package network
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/CityOfZion/neo-go/pkg/network/payload"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestMessageEncodeDecode(t *testing.T) {
|
||||
m := newMessage(ModeTestNet, cmdVersion, nil)
|
||||
m := NewMessage(ModeTestNet, CMDVersion, nil)
|
||||
|
||||
buf := &bytes.Buffer{}
|
||||
if err := m.encode(buf); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if n := len(buf.Bytes()); n < minMessageSize {
|
||||
t.Fatalf("message should be at least %d bytes got %d", minMessageSize, n)
|
||||
}
|
||||
if n := len(buf.Bytes()); n > minMessageSize {
|
||||
t.Fatalf("message without a payload should be exact %d bytes got %d", minMessageSize, n)
|
||||
}
|
||||
assert.Equal(t, len(buf.Bytes()), minMessageSize)
|
||||
|
||||
md := &Message{}
|
||||
if err := md.decode(buf); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if !reflect.DeepEqual(m, md) {
|
||||
t.Errorf("both messages should be equal: %v != %v", m, md)
|
||||
}
|
||||
assert.Equal(t, m, md)
|
||||
}
|
||||
|
||||
func TestMessageEncodeDecodeWithVersion(t *testing.T) {
|
||||
p := payload.NewVersion(12227, 2000, "/neo:2.6.0/", 0, true)
|
||||
m := newMessage(ModeTestNet, cmdVersion, p)
|
||||
var (
|
||||
p = payload.NewVersion(12227, 2000, "/neo:2.6.0/", 0, true)
|
||||
m = NewMessage(ModeTestNet, CMDVersion, p)
|
||||
)
|
||||
|
||||
buf := new(bytes.Buffer)
|
||||
if err := m.encode(buf); err != nil {
|
||||
|
@ -45,15 +39,14 @@ func TestMessageEncodeDecodeWithVersion(t *testing.T) {
|
|||
if err := mDecode.decode(buf); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(m, mDecode) {
|
||||
t.Fatalf("expected both messages to be equal %v and %v", m, mDecode)
|
||||
}
|
||||
assert.Equal(t, m, mDecode)
|
||||
}
|
||||
|
||||
func TestMessageInvalidChecksum(t *testing.T) {
|
||||
p := payload.NewVersion(1111, 3000, "/NEO:2.6.0/", 0, true)
|
||||
m := newMessage(ModeTestNet, cmdVersion, p)
|
||||
var (
|
||||
p = payload.NewVersion(1111, 3000, "/NEO:2.6.0/", 0, true)
|
||||
m = NewMessage(ModeTestNet, CMDVersion, p)
|
||||
)
|
||||
m.Checksum = 1337
|
||||
|
||||
buf := new(bytes.Buffer)
|
||||
|
|
219
pkg/network/node.go
Normal file
219
pkg/network/node.go
Normal file
|
@ -0,0 +1,219 @@
|
|||
package network
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/CityOfZion/neo-go/pkg/core"
|
||||
"github.com/CityOfZion/neo-go/pkg/network/payload"
|
||||
"github.com/CityOfZion/neo-go/pkg/util"
|
||||
log "github.com/go-kit/kit/log"
|
||||
)
|
||||
|
||||
const (
|
||||
protoVersion = 0
|
||||
)
|
||||
|
||||
var protoTickInterval = 5 * time.Second
|
||||
|
||||
// Node represents the local node.
|
||||
type Node struct {
|
||||
// Config fields may not be modified while the server is running.
|
||||
Config
|
||||
|
||||
logger log.Logger
|
||||
server *Server
|
||||
services uint64
|
||||
bc *core.Blockchain
|
||||
protoIn chan messageTuple
|
||||
}
|
||||
|
||||
// messageTuple respresents a tuple that holds the message being
|
||||
// send along with its peer.
|
||||
type messageTuple struct {
|
||||
peer Peer
|
||||
msg *Message
|
||||
}
|
||||
|
||||
func newNode(s *Server, cfg Config) *Node {
|
||||
var startHash util.Uint256
|
||||
if cfg.Net == ModePrivNet {
|
||||
startHash = core.GenesisHashPrivNet()
|
||||
}
|
||||
|
||||
bc := core.NewBlockchain(
|
||||
core.NewMemoryStore(),
|
||||
startHash,
|
||||
)
|
||||
|
||||
logger := log.NewLogfmtLogger(os.Stderr)
|
||||
logger = log.With(logger, "component", "node")
|
||||
|
||||
n := &Node{
|
||||
Config: cfg,
|
||||
protoIn: make(chan messageTuple),
|
||||
server: s,
|
||||
bc: bc,
|
||||
logger: logger,
|
||||
}
|
||||
go n.handleMessages()
|
||||
|
||||
return n
|
||||
}
|
||||
|
||||
func (n *Node) version() *payload.Version {
|
||||
return payload.NewVersion(n.server.id, n.ListenTCP, n.UserAgent, 1, n.Relay)
|
||||
}
|
||||
|
||||
func (n *Node) startProtocol(peer Peer) {
|
||||
ticker := time.NewTicker(protoTickInterval).C
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker:
|
||||
// Try to sync with the peer if his block height is higher then ours.
|
||||
if peer.Version().StartHeight > n.bc.HeaderHeight() {
|
||||
n.askMoreHeaders(peer)
|
||||
}
|
||||
// Only ask for more peers if the server has the capacity for it.
|
||||
if n.server.hasCapacity() {
|
||||
msg := NewMessage(n.Net, CMDGetAddr, nil)
|
||||
peer.Send(msg)
|
||||
}
|
||||
case <-peer.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// When a peer sends out his version we reply with verack after validating
|
||||
// the version.
|
||||
func (n *Node) handleVersionCmd(version *payload.Version, peer Peer) error {
|
||||
msg := NewMessage(n.Net, CMDVerack, nil)
|
||||
peer.Send(msg)
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleInvCmd handles the forwarded inventory received from the peer.
|
||||
// We will use the getdata message to get more details about the received
|
||||
// inventory.
|
||||
// note: if the server has Relay on false, inventory messages are not received.
|
||||
func (n *Node) handleInvCmd(inv *payload.Inventory, peer Peer) error {
|
||||
if !inv.Type.Valid() {
|
||||
return fmt.Errorf("invalid inventory type received: %s", inv.Type)
|
||||
}
|
||||
if len(inv.Hashes) == 0 {
|
||||
return errors.New("inventory has no hashes")
|
||||
}
|
||||
payload := payload.NewInventory(inv.Type, inv.Hashes)
|
||||
peer.Send(NewMessage(n.Net, CMDGetData, payload))
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleBlockCmd processes the received block received from its peer.
|
||||
func (n *Node) handleBlockCmd(block *core.Block, peer Peer) error {
|
||||
n.logger.Log(
|
||||
"event", "block received",
|
||||
"index", block.Index,
|
||||
"hash", block.Hash(),
|
||||
"tx", len(block.Transactions),
|
||||
)
|
||||
|
||||
return n.bc.AddBlock(block)
|
||||
}
|
||||
|
||||
// After a node sends out the getaddr message its receives a list of known peers
|
||||
// in the network. handleAddrCmd processes that payload.
|
||||
func (n *Node) handleAddrCmd(addressList *payload.AddressList, peer Peer) error {
|
||||
addrs := make([]string, len(addressList.Addrs))
|
||||
for i := 0; i < len(addrs); i++ {
|
||||
addrs[i] = addressList.Addrs[i].Address.String()
|
||||
}
|
||||
n.server.connectToPeers(addrs...)
|
||||
return nil
|
||||
}
|
||||
|
||||
// The handleHeadersCmd will process the received headers from its peer.
|
||||
// We call this in a routine cause we may block Peers Send() for to long.
|
||||
func (n *Node) handleHeadersCmd(headers *payload.Headers, peer Peer) error {
|
||||
go func(headers []*core.Header) {
|
||||
if err := n.bc.AddHeaders(headers...); err != nil {
|
||||
n.logger.Log("msg", "failed processing headers", "err", err)
|
||||
return
|
||||
}
|
||||
// The peer will respond with a maximum of 2000 headers in one batch.
|
||||
// We will ask one more batch here if needed. Eventually we will get synced
|
||||
// due to the startProtocol routine that will ask headers every protoTick.
|
||||
if n.bc.HeaderHeight() < peer.Version().StartHeight {
|
||||
n.askMoreHeaders(peer)
|
||||
}
|
||||
}(headers.Hdrs)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// askMoreHeaders will send a getheaders message to the peer.
|
||||
func (n *Node) askMoreHeaders(p Peer) {
|
||||
start := []util.Uint256{n.bc.CurrentHeaderHash()}
|
||||
payload := payload.NewGetBlocks(start, util.Uint256{})
|
||||
p.Send(NewMessage(n.Net, CMDGetHeaders, payload))
|
||||
}
|
||||
|
||||
// blockhain implements the Noder interface.
|
||||
func (n *Node) blockchain() *core.Blockchain { return n.bc }
|
||||
|
||||
// handleProto implements the protoHandler interface.
|
||||
func (n *Node) handleProto(msg *Message, p Peer) {
|
||||
n.protoIn <- messageTuple{
|
||||
msg: msg,
|
||||
peer: p,
|
||||
}
|
||||
}
|
||||
|
||||
func (n *Node) handleMessages() {
|
||||
for {
|
||||
t := <-n.protoIn
|
||||
|
||||
var (
|
||||
msg = t.msg
|
||||
p = t.peer
|
||||
err error
|
||||
)
|
||||
|
||||
switch msg.CommandType() {
|
||||
case CMDVersion:
|
||||
version := msg.Payload.(*payload.Version)
|
||||
err = n.handleVersionCmd(version, p)
|
||||
case CMDAddr:
|
||||
addressList := msg.Payload.(*payload.AddressList)
|
||||
err = n.handleAddrCmd(addressList, p)
|
||||
case CMDInv:
|
||||
inventory := msg.Payload.(*payload.Inventory)
|
||||
err = n.handleInvCmd(inventory, p)
|
||||
case CMDBlock:
|
||||
block := msg.Payload.(*core.Block)
|
||||
err = n.handleBlockCmd(block, p)
|
||||
case CMDHeaders:
|
||||
headers := msg.Payload.(*payload.Headers)
|
||||
err = n.handleHeadersCmd(headers, p)
|
||||
case CMDVerack:
|
||||
// Only start the protocol if we got the version and verack
|
||||
// received.
|
||||
if p.Version() != nil {
|
||||
go n.startProtocol(p)
|
||||
}
|
||||
case CMDUnknown:
|
||||
err = errors.New("received non-protocol messgae")
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
n.logger.Log(
|
||||
"msg", "failed processing message",
|
||||
"command", msg.CommandType,
|
||||
"err", err,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
7
pkg/network/node_test.go
Normal file
7
pkg/network/node_test.go
Normal file
|
@ -0,0 +1,7 @@
|
|||
package network
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestHandleVersion(t *testing.T) {
|
||||
|
||||
}
|
|
@ -1,90 +0,0 @@
|
|||
package payload
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"io"
|
||||
|
||||
"github.com/CityOfZion/neo-go/pkg/util"
|
||||
)
|
||||
|
||||
// AddrWithTime payload
|
||||
type AddrWithTime struct {
|
||||
// Timestamp the node connected to the network.
|
||||
Timestamp uint32
|
||||
Services uint64
|
||||
Addr util.Endpoint
|
||||
}
|
||||
|
||||
// NewAddrWithTime return a pointer to AddrWithTime.
|
||||
func NewAddrWithTime(addr util.Endpoint) *AddrWithTime {
|
||||
return &AddrWithTime{
|
||||
Timestamp: 1337,
|
||||
Services: 1,
|
||||
Addr: addr,
|
||||
}
|
||||
}
|
||||
|
||||
// Size implements the payload interface.
|
||||
func (p *AddrWithTime) Size() uint32 {
|
||||
return 30
|
||||
}
|
||||
|
||||
// DecodeBinary implements the Payload interface.
|
||||
func (p *AddrWithTime) DecodeBinary(r io.Reader) error {
|
||||
err := binary.Read(r, binary.LittleEndian, &p.Timestamp)
|
||||
err = binary.Read(r, binary.LittleEndian, &p.Services)
|
||||
err = binary.Read(r, binary.BigEndian, &p.Addr.IP)
|
||||
err = binary.Read(r, binary.BigEndian, &p.Addr.Port)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// EncodeBinary implements the Payload interface.
|
||||
func (p *AddrWithTime) EncodeBinary(w io.Writer) error {
|
||||
err := binary.Write(w, binary.LittleEndian, p.Timestamp)
|
||||
err = binary.Write(w, binary.LittleEndian, p.Services)
|
||||
err = binary.Write(w, binary.BigEndian, p.Addr.IP)
|
||||
err = binary.Write(w, binary.BigEndian, p.Addr.Port)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// AddressList is a list with AddrWithTime.
|
||||
type AddressList struct {
|
||||
Addrs []*AddrWithTime
|
||||
}
|
||||
|
||||
// DecodeBinary implements the Payload interface.
|
||||
func (p *AddressList) DecodeBinary(r io.Reader) error {
|
||||
listLen := util.ReadVarUint(r)
|
||||
|
||||
p.Addrs = make([]*AddrWithTime, listLen)
|
||||
for i := 0; i < int(listLen); i++ {
|
||||
addr := &AddrWithTime{}
|
||||
if err := addr.DecodeBinary(r); err != nil {
|
||||
return err
|
||||
}
|
||||
p.Addrs[i] = addr
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// EncodeBinary implements the Payload interface.
|
||||
func (p *AddressList) EncodeBinary(w io.Writer) error {
|
||||
// Write the length of the slice
|
||||
util.WriteVarUint(w, uint64(len(p.Addrs)))
|
||||
|
||||
for _, addr := range p.Addrs {
|
||||
if err := addr.EncodeBinary(w); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Size implements the Payloader interface.
|
||||
func (p *AddressList) Size() uint32 {
|
||||
return uint32(len(p.Addrs) * 30)
|
||||
}
|
|
@ -1,55 +0,0 @@
|
|||
package payload
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/CityOfZion/neo-go/pkg/util"
|
||||
)
|
||||
|
||||
func TestEncodeDecodeAddr(t *testing.T) {
|
||||
e, err := util.EndpointFromString("127.0.0.1:2000")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := NewAddrWithTime(e)
|
||||
buf := new(bytes.Buffer)
|
||||
if err := addr.EncodeBinary(buf); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addrDecode := &AddrWithTime{}
|
||||
if err := addrDecode.DecodeBinary(buf); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(addr, addrDecode) {
|
||||
t.Fatalf("expected both addr payloads to be equal: %v and %v", addr, addrDecode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeDecodeAddressList(t *testing.T) {
|
||||
var lenList uint8 = 4
|
||||
addrList := &AddressList{make([]*AddrWithTime, lenList)}
|
||||
for i := 0; i < int(lenList); i++ {
|
||||
e, _ := util.EndpointFromString(fmt.Sprintf("127.0.0.1:200%d", i))
|
||||
addrList.Addrs[i] = NewAddrWithTime(e)
|
||||
}
|
||||
|
||||
buf := new(bytes.Buffer)
|
||||
if err := addrList.EncodeBinary(buf); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addrListDecode := &AddressList{}
|
||||
if err := addrListDecode.DecodeBinary(buf); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(addrList, addrListDecode) {
|
||||
t.Fatalf("expected both address list payloads to be equal: %v and %v", addrList, addrListDecode)
|
||||
}
|
||||
}
|
85
pkg/network/payload/address.go
Normal file
85
pkg/network/payload/address.go
Normal file
|
@ -0,0 +1,85 @@
|
|||
package payload
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/CityOfZion/neo-go/pkg/util"
|
||||
)
|
||||
|
||||
// AddressAndTime payload.
|
||||
type AddressAndTime struct {
|
||||
Timestamp uint32
|
||||
Services uint64
|
||||
Address util.Endpoint
|
||||
}
|
||||
|
||||
// NewAddressAndTime creates a new AddressAndTime object.
|
||||
func NewAddressAndTime(e util.Endpoint, t time.Time) *AddressAndTime {
|
||||
return &AddressAndTime{
|
||||
Timestamp: uint32(t.UTC().Unix()),
|
||||
Services: 1,
|
||||
Address: e,
|
||||
}
|
||||
}
|
||||
|
||||
// DecodeBinary implements the Payload interface.
|
||||
func (p *AddressAndTime) DecodeBinary(r io.Reader) error {
|
||||
if err := binary.Read(r, binary.LittleEndian, &p.Timestamp); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := binary.Read(r, binary.LittleEndian, &p.Services); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := binary.Read(r, binary.BigEndian, &p.Address.IP); err != nil {
|
||||
return err
|
||||
}
|
||||
return binary.Read(r, binary.BigEndian, &p.Address.Port)
|
||||
}
|
||||
|
||||
// EncodeBinary implements the Payload interface.
|
||||
func (p *AddressAndTime) EncodeBinary(w io.Writer) error {
|
||||
if err := binary.Write(w, binary.LittleEndian, p.Timestamp); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := binary.Write(w, binary.LittleEndian, p.Services); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := binary.Write(w, binary.BigEndian, p.Address.IP); err != nil {
|
||||
return err
|
||||
}
|
||||
return binary.Write(w, binary.BigEndian, p.Address.Port)
|
||||
}
|
||||
|
||||
// AddressList is a list with AddrAndTime.
|
||||
type AddressList struct {
|
||||
Addrs []*AddressAndTime
|
||||
}
|
||||
|
||||
// DecodeBinary implements the Payload interface.
|
||||
func (p *AddressList) DecodeBinary(r io.Reader) error {
|
||||
listLen := util.ReadVarUint(r)
|
||||
|
||||
p.Addrs = make([]*AddressAndTime, listLen)
|
||||
for i := 0; i < int(listLen); i++ {
|
||||
p.Addrs[i] = &AddressAndTime{}
|
||||
if err := p.Addrs[i].DecodeBinary(r); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// EncodeBinary implements the Payload interface.
|
||||
func (p *AddressList) EncodeBinary(w io.Writer) error {
|
||||
if err := util.WriteVarUint(w, uint64(len(p.Addrs))); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, addr := range p.Addrs {
|
||||
if err := addr.EncodeBinary(w); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
51
pkg/network/payload/address_test.go
Normal file
51
pkg/network/payload/address_test.go
Normal file
|
@ -0,0 +1,51 @@
|
|||
package payload
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/CityOfZion/neo-go/pkg/util"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestEncodeDecodeAddress(t *testing.T) {
|
||||
var (
|
||||
e = util.NewEndpoint("127.0.0.1:2000")
|
||||
addr = NewAddressAndTime(e, time.Now())
|
||||
buf = new(bytes.Buffer)
|
||||
)
|
||||
|
||||
if err := addr.EncodeBinary(buf); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addrDecode := &AddressAndTime{}
|
||||
if err := addrDecode.DecodeBinary(buf); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
assert.Equal(t, addr, addrDecode)
|
||||
}
|
||||
|
||||
func TestEncodeDecodeAddressList(t *testing.T) {
|
||||
var lenList uint8 = 4
|
||||
addrList := &AddressList{make([]*AddressAndTime, lenList)}
|
||||
for i := 0; i < int(lenList); i++ {
|
||||
e := util.NewEndpoint(fmt.Sprintf("127.0.0.1:200%d", i))
|
||||
addrList.Addrs[i] = NewAddressAndTime(e, time.Now())
|
||||
}
|
||||
|
||||
buf := new(bytes.Buffer)
|
||||
if err := addrList.EncodeBinary(buf); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addrListDecode := &AddressList{}
|
||||
if err := addrListDecode.DecodeBinary(buf); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
assert.Equal(t, addrList, addrListDecode)
|
||||
}
|
|
@ -37,6 +37,5 @@ func (p *Headers) EncodeBinary(w io.Writer) error {
|
|||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -2,11 +2,11 @@ package payload
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/CityOfZion/neo-go/pkg/core"
|
||||
"github.com/CityOfZion/neo-go/pkg/core/transaction"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestHeadersEncodeDecode(t *testing.T) {
|
||||
|
@ -50,7 +50,9 @@ func TestHeadersEncodeDecode(t *testing.T) {
|
|||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(headers, headersDecode) {
|
||||
t.Fatalf("expected both header payload to be equal %+v and %+v", headers, headersDecode)
|
||||
for i := 0; i < len(headers.Hdrs); i++ {
|
||||
assert.Equal(t, headers.Hdrs[i].Version, headersDecode.Hdrs[i].Version)
|
||||
assert.Equal(t, headers.Hdrs[i].Index, headersDecode.Hdrs[i].Index)
|
||||
assert.Equal(t, headers.Hdrs[i].Script, headersDecode.Hdrs[i].Script)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -35,8 +35,8 @@ func (i InventoryType) Valid() bool {
|
|||
// List of valid InventoryTypes.
|
||||
const (
|
||||
BlockType InventoryType = 0x01 // 1
|
||||
TXType = 0x02 // 2
|
||||
ConsensusType = 0xe0 // 224
|
||||
TXType InventoryType = 0x02 // 2
|
||||
ConsensusType InventoryType = 0xe0 // 224
|
||||
)
|
||||
|
||||
// Inventory payload
|
||||
|
|
|
@ -44,36 +44,63 @@ func NewVersion(id uint32, p uint16, ua string, h uint32, r bool) *Version {
|
|||
|
||||
// DecodeBinary implements the Payload interface.
|
||||
func (p *Version) DecodeBinary(r io.Reader) error {
|
||||
err := binary.Read(r, binary.LittleEndian, &p.Version)
|
||||
err = binary.Read(r, binary.LittleEndian, &p.Services)
|
||||
err = binary.Read(r, binary.LittleEndian, &p.Timestamp)
|
||||
err = binary.Read(r, binary.LittleEndian, &p.Port)
|
||||
err = binary.Read(r, binary.LittleEndian, &p.Nonce)
|
||||
if err := binary.Read(r, binary.LittleEndian, &p.Version); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := binary.Read(r, binary.LittleEndian, &p.Services); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := binary.Read(r, binary.LittleEndian, &p.Timestamp); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := binary.Read(r, binary.LittleEndian, &p.Port); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := binary.Read(r, binary.LittleEndian, &p.Nonce); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var lenUA uint8
|
||||
err = binary.Read(r, binary.LittleEndian, &lenUA)
|
||||
if err := binary.Read(r, binary.LittleEndian, &lenUA); err != nil {
|
||||
return err
|
||||
}
|
||||
p.UserAgent = make([]byte, lenUA)
|
||||
err = binary.Read(r, binary.LittleEndian, &p.UserAgent)
|
||||
|
||||
err = binary.Read(r, binary.LittleEndian, &p.StartHeight)
|
||||
err = binary.Read(r, binary.LittleEndian, &p.Relay)
|
||||
|
||||
return err
|
||||
if err := binary.Read(r, binary.LittleEndian, &p.UserAgent); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := binary.Read(r, binary.LittleEndian, &p.StartHeight); err != nil {
|
||||
return err
|
||||
}
|
||||
return binary.Read(r, binary.LittleEndian, &p.Relay)
|
||||
}
|
||||
|
||||
// EncodeBinary implements the Payload interface.
|
||||
func (p *Version) EncodeBinary(w io.Writer) error {
|
||||
err := binary.Write(w, binary.LittleEndian, p.Version)
|
||||
err = binary.Write(w, binary.LittleEndian, p.Services)
|
||||
err = binary.Write(w, binary.LittleEndian, p.Timestamp)
|
||||
err = binary.Write(w, binary.LittleEndian, p.Port)
|
||||
err = binary.Write(w, binary.LittleEndian, p.Nonce)
|
||||
err = binary.Write(w, binary.LittleEndian, uint8(len(p.UserAgent)))
|
||||
err = binary.Write(w, binary.LittleEndian, p.UserAgent)
|
||||
err = binary.Write(w, binary.LittleEndian, p.StartHeight)
|
||||
err = binary.Write(w, binary.LittleEndian, p.Relay)
|
||||
|
||||
return err
|
||||
if err := binary.Write(w, binary.LittleEndian, p.Version); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := binary.Write(w, binary.LittleEndian, p.Services); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := binary.Write(w, binary.LittleEndian, p.Timestamp); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := binary.Write(w, binary.LittleEndian, p.Port); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := binary.Write(w, binary.LittleEndian, p.Nonce); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := binary.Write(w, binary.LittleEndian, uint8(len(p.UserAgent))); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := binary.Write(w, binary.LittleEndian, p.UserAgent); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := binary.Write(w, binary.LittleEndian, p.StartHeight); err != nil {
|
||||
return err
|
||||
}
|
||||
return binary.Write(w, binary.LittleEndian, p.Relay)
|
||||
}
|
||||
|
||||
// Size implements the payloader interface.
|
||||
|
|
|
@ -5,48 +5,12 @@ import (
|
|||
"github.com/CityOfZion/neo-go/pkg/util"
|
||||
)
|
||||
|
||||
// Peer is the local representation of a remote node. It's an interface that may
|
||||
// be backed by any concrete transport: local, HTTP, tcp.
|
||||
// A Peer is the local representation of a remote peer.
|
||||
// It's an interface that may be backed by any concrete
|
||||
// transport.
|
||||
type Peer interface {
|
||||
id() uint32
|
||||
addr() util.Endpoint
|
||||
disconnect()
|
||||
Send(*Message) error
|
||||
version() *payload.Version
|
||||
Version() *payload.Version
|
||||
Endpoint() util.Endpoint
|
||||
Send(*Message)
|
||||
Done() chan struct{}
|
||||
}
|
||||
|
||||
// LocalPeer is the simplest kind of peer, mapped to a server in the
|
||||
// same process-space.
|
||||
type LocalPeer struct {
|
||||
s *Server
|
||||
nonce uint32
|
||||
endpoint util.Endpoint
|
||||
pVersion *payload.Version
|
||||
}
|
||||
|
||||
// NewLocalPeer return a LocalPeer.
|
||||
func NewLocalPeer(s *Server) *LocalPeer {
|
||||
e, _ := util.EndpointFromString("1.1.1.1:1111")
|
||||
return &LocalPeer{endpoint: e, s: s}
|
||||
}
|
||||
|
||||
func (p *LocalPeer) Send(msg *Message) error {
|
||||
switch msg.commandType() {
|
||||
case cmdVersion:
|
||||
version := msg.Payload.(*payload.Version)
|
||||
return p.s.handleVersionCmd(version, p)
|
||||
case cmdGetAddr:
|
||||
return p.s.handleGetaddrCmd(msg, p)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Version implements the Peer interface.
|
||||
func (p *LocalPeer) version() *payload.Version {
|
||||
return p.pVersion
|
||||
}
|
||||
|
||||
func (p *LocalPeer) id() uint32 { return p.nonce }
|
||||
func (p *LocalPeer) addr() util.Endpoint { return p.endpoint }
|
||||
func (p *LocalPeer) disconnect() {}
|
||||
|
|
22
pkg/network/protocol.go
Normal file
22
pkg/network/protocol.go
Normal file
|
@ -0,0 +1,22 @@
|
|||
package network
|
||||
|
||||
import (
|
||||
"github.com/CityOfZion/neo-go/pkg/core"
|
||||
"github.com/CityOfZion/neo-go/pkg/network/payload"
|
||||
)
|
||||
|
||||
// A ProtoHandler is an interface that abstract the implementation
|
||||
// of the NEO protocol.
|
||||
type ProtoHandler interface {
|
||||
version() *payload.Version
|
||||
handleProto(*Message, Peer)
|
||||
}
|
||||
|
||||
type protoHandleFunc func(*Message, Peer)
|
||||
|
||||
// Noder is anything that implements the NEO protocol
|
||||
// and can return the Blockchain object.
|
||||
type Noder interface {
|
||||
ProtoHandler
|
||||
blockchain() *core.Blockchain
|
||||
}
|
|
@ -1,129 +0,0 @@
|
|||
package network
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
const (
|
||||
rpcPortMainNet = 20332
|
||||
rpcPortTestNet = 10332
|
||||
rpcVersion = "2.0"
|
||||
|
||||
// error response messages
|
||||
methodNotFound = "Method not found"
|
||||
parseError = "Parse error"
|
||||
)
|
||||
|
||||
// Each NEO node has a set of optional APIs for accessing blockchain
|
||||
// data and making things easier for development of blockchain apps.
|
||||
// APIs are provided via JSON-RPC , comm at bottom layer is with http/https protocol.
|
||||
|
||||
// listenHTTP creates an ingress bridge from the outside world to the passed
|
||||
// server, by installing handlers for all the necessary RPCs to the passed mux.
|
||||
func listenHTTP(s *Server, port int) {
|
||||
api := &API{s}
|
||||
p := fmt.Sprintf(":%d", port)
|
||||
s.logger.Printf("serving RPC on %d", port)
|
||||
s.logger.Printf("%s", http.ListenAndServe(p, api))
|
||||
}
|
||||
|
||||
// API serves JSON-RPC.
|
||||
type API struct {
|
||||
s *Server
|
||||
}
|
||||
|
||||
func (s *API) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
// Official nodes respond a parse error if the method is not POST.
|
||||
// Instead of returning a decent response for this, let's do the same.
|
||||
if r.Method != "POST" {
|
||||
writeError(w, 0, 0, parseError)
|
||||
}
|
||||
|
||||
var req Request
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, 0, 0, parseError)
|
||||
return
|
||||
}
|
||||
defer r.Body.Close()
|
||||
|
||||
if req.Version != rpcVersion {
|
||||
writeJSON(w, http.StatusBadRequest, nil)
|
||||
return
|
||||
}
|
||||
|
||||
switch req.Method {
|
||||
case "getconnectioncount":
|
||||
if err := s.getConnectionCount(w, &req); err != nil {
|
||||
writeError(w, 0, 0, parseError)
|
||||
return
|
||||
}
|
||||
case "getblockcount":
|
||||
case "getbestblockhash":
|
||||
default:
|
||||
writeError(w, 0, 0, methodNotFound)
|
||||
}
|
||||
}
|
||||
|
||||
// This is an Example on how we could handle incomming RPC requests.
|
||||
func (s *API) getConnectionCount(w http.ResponseWriter, req *Request) error {
|
||||
count := s.s.peerCount()
|
||||
|
||||
resp := ConnectionCountResponse{
|
||||
Version: rpcVersion,
|
||||
Result: count,
|
||||
ID: 1,
|
||||
}
|
||||
|
||||
return writeJSON(w, http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// writeError returns a JSON error with given parameters. All error HTTP
|
||||
// status codes are 200. According to the official API.
|
||||
func writeError(w http.ResponseWriter, id, code int, msg string) error {
|
||||
resp := RequestError{
|
||||
Version: rpcVersion,
|
||||
ID: id,
|
||||
Error: Error{
|
||||
Code: code,
|
||||
Message: msg,
|
||||
},
|
||||
}
|
||||
|
||||
return writeJSON(w, http.StatusOK, resp)
|
||||
}
|
||||
|
||||
func writeJSON(w http.ResponseWriter, status int, v interface{}) error {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
return json.NewEncoder(w).Encode(v)
|
||||
}
|
||||
|
||||
// Request is an object received through JSON-RPC from the client.
|
||||
type Request struct {
|
||||
Version string `json:"jsonrpc"`
|
||||
Method string `json:"method"`
|
||||
Params []string `json:"params"`
|
||||
ID int `json:"id"`
|
||||
}
|
||||
|
||||
// ConnectionCountResponse ..
|
||||
type ConnectionCountResponse struct {
|
||||
Version string `json:"jsonrpc"`
|
||||
Result int `json:"result"`
|
||||
ID int `json:"id"`
|
||||
}
|
||||
|
||||
// RequestError ..
|
||||
type RequestError struct {
|
||||
Version string `json:"jsonrpc"`
|
||||
ID int `json:"id"`
|
||||
Error Error `json:"error"`
|
||||
}
|
||||
|
||||
// Error holds information about an RCP error.
|
||||
type Error struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
|
@ -1,354 +1,282 @@
|
|||
package network
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"text/tabwriter"
|
||||
"time"
|
||||
|
||||
"github.com/CityOfZion/neo-go/pkg/core"
|
||||
"github.com/CityOfZion/neo-go/pkg/network/payload"
|
||||
"github.com/CityOfZion/neo-go/pkg/util"
|
||||
log "github.com/go-kit/kit/log"
|
||||
)
|
||||
|
||||
const (
|
||||
// node version
|
||||
version = "2.6.0"
|
||||
|
||||
// official ports according to the protocol.
|
||||
portMainNet = 10333
|
||||
portTestNet = 20333
|
||||
maxPeers = 50
|
||||
)
|
||||
|
||||
type messageTuple struct {
|
||||
peer Peer
|
||||
msg *Message
|
||||
var dialTimeout = 4 * time.Second
|
||||
|
||||
// Config holds the server configuration.
|
||||
type Config struct {
|
||||
// MaxPeers it the maximum numbers of peers that can
|
||||
// be connected to the server.
|
||||
MaxPeers int
|
||||
|
||||
// The user agent of the server.
|
||||
UserAgent string
|
||||
|
||||
// The listen address of the TCP server.
|
||||
ListenTCP uint16
|
||||
|
||||
// The listen address of the RPC server.
|
||||
ListenRPC uint16
|
||||
|
||||
// The network mode this server will operate on.
|
||||
// ModePrivNet docker private network.
|
||||
// ModeTestNet NEO test network.
|
||||
// ModeMainNet NEO main network.
|
||||
Net NetMode
|
||||
|
||||
// Relay determins whether the server is forwarding its inventory.
|
||||
Relay bool
|
||||
|
||||
// Seeds are a list of initial nodes used to establish connectivity.
|
||||
Seeds []string
|
||||
|
||||
// Maximum duration a single dial may take.
|
||||
DialTimeout time.Duration
|
||||
}
|
||||
|
||||
// Server is the representation of a full working NEO TCP node.
|
||||
// Server manages all incoming peer connections.
|
||||
type Server struct {
|
||||
logger *log.Logger
|
||||
// id of the server
|
||||
// Config fields may not be modified while the server is running.
|
||||
Config
|
||||
|
||||
// Proto is just about anything that can handle the NEO protocol.
|
||||
// In production enviroments the ProtoHandler is mostly the local node.
|
||||
proto ProtoHandler
|
||||
|
||||
// Unique id of this server.
|
||||
id uint32
|
||||
// the port the TCP listener is listening on.
|
||||
port uint16
|
||||
// userAgent of the server.
|
||||
userAgent string
|
||||
// The "magic" mode the server is currently running on.
|
||||
// This can either be 0x00746e41 or 0x74746e41 for main or test net.
|
||||
// Or 56753 to work with the docker privnet.
|
||||
net NetMode
|
||||
// map that holds all connected peers to this server.
|
||||
peers map[Peer]bool
|
||||
// channel for handling new registerd peers.
|
||||
register chan Peer
|
||||
// channel for safely removing and disconnecting peers.
|
||||
unregister chan Peer
|
||||
// channel for coordinating messages.
|
||||
message chan messageTuple
|
||||
// channel used to gracefull shutdown the server.
|
||||
quit chan struct{}
|
||||
// Whether this server will receive and forward messages.
|
||||
relay bool
|
||||
// TCP listener of the server
|
||||
|
||||
logger log.Logger
|
||||
listener net.Listener
|
||||
// channel for safely responding the number of current connected peers.
|
||||
peerCountCh chan peerCount
|
||||
// a list of hashes that
|
||||
knownHashes protectedHashmap
|
||||
// The blockchain.
|
||||
bc *core.Blockchain
|
||||
|
||||
register chan Peer
|
||||
unregister chan Peer
|
||||
|
||||
badAddrOp chan func(map[string]bool)
|
||||
badAddrOpDone chan struct{}
|
||||
|
||||
peerOp chan func(map[Peer]bool)
|
||||
peerOpDone chan struct{}
|
||||
|
||||
quit chan struct{}
|
||||
}
|
||||
|
||||
// TODO: Maybe util is a better place for such data types.
|
||||
type protectedHashmap struct {
|
||||
*sync.RWMutex
|
||||
hashes map[util.Uint256]bool
|
||||
}
|
||||
|
||||
func (m protectedHashmap) add(h util.Uint256) bool {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
|
||||
if _, ok := m.hashes[h]; !ok {
|
||||
m.hashes[h] = true
|
||||
return true
|
||||
// NewServer returns a new Server object created from the
|
||||
// given config.
|
||||
func NewServer(cfg Config) *Server {
|
||||
if cfg.MaxPeers == 0 {
|
||||
cfg.MaxPeers = maxPeers
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (m protectedHashmap) remove(h util.Uint256) bool {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
|
||||
if _, ok := m.hashes[h]; ok {
|
||||
delete(m.hashes, h)
|
||||
return true
|
||||
if cfg.Net == 0 {
|
||||
cfg.Net = ModeTestNet
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (m protectedHashmap) has(h util.Uint256) bool {
|
||||
m.RLock()
|
||||
defer m.RUnlock()
|
||||
|
||||
_, ok := m.hashes[h]
|
||||
|
||||
return ok
|
||||
}
|
||||
|
||||
// NewServer returns a pointer to a new server.
|
||||
func NewServer(net NetMode) *Server {
|
||||
logger := log.New(os.Stdout, "[NEO SERVER] :: ", 0)
|
||||
|
||||
if net != ModeTestNet && net != ModeMainNet && net != ModePrivNet {
|
||||
logger.Fatalf("invalid network mode %d", net)
|
||||
if cfg.DialTimeout == 0 {
|
||||
cfg.DialTimeout = dialTimeout
|
||||
}
|
||||
|
||||
// For now I will hard code a genesis block of the docker privnet container.
|
||||
startHash, _ := util.Uint256DecodeString("996e37358dc369912041f966f8c5d8d3a8255ba5dcbd3447f8a82b55db869099")
|
||||
logger := log.NewLogfmtLogger(os.Stderr)
|
||||
logger = log.With(logger, "component", "server")
|
||||
|
||||
s := &Server{
|
||||
id: util.RandUint32(1111111, 9999999),
|
||||
userAgent: fmt.Sprintf("/NEO:%s/", version),
|
||||
logger: logger,
|
||||
peers: make(map[Peer]bool),
|
||||
register: make(chan Peer),
|
||||
unregister: make(chan Peer),
|
||||
message: make(chan messageTuple),
|
||||
relay: true, // currently relay is not handled.
|
||||
net: net,
|
||||
quit: make(chan struct{}),
|
||||
peerCountCh: make(chan peerCount),
|
||||
bc: core.NewBlockchain(core.NewMemoryStore(), logger, startHash),
|
||||
Config: cfg,
|
||||
logger: logger,
|
||||
id: util.RandUint32(1000000, 9999999),
|
||||
quit: make(chan struct{}, 1),
|
||||
register: make(chan Peer),
|
||||
unregister: make(chan Peer),
|
||||
badAddrOp: make(chan func(map[string]bool)),
|
||||
badAddrOpDone: make(chan struct{}),
|
||||
peerOp: make(chan func(map[Peer]bool)),
|
||||
peerOpDone: make(chan struct{}),
|
||||
}
|
||||
|
||||
s.proto = newNode(s, cfg)
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// Start run's the server.
|
||||
// TODO: server should be initialized with a config.
|
||||
func (s *Server) Start(opts StartOpts) {
|
||||
s.port = uint16(opts.TCP)
|
||||
|
||||
fmt.Println(logo())
|
||||
fmt.Println(string(s.userAgent))
|
||||
fmt.Println("")
|
||||
s.logger.Printf("NET: %s - TCP: %d - RELAY: %v - ID: %d",
|
||||
s.net, int(s.port), s.relay, s.id)
|
||||
|
||||
go listenTCP(s, opts.TCP)
|
||||
|
||||
if opts.RPC > 0 {
|
||||
go listenHTTP(s, opts.RPC)
|
||||
}
|
||||
|
||||
if len(opts.Seeds) > 0 {
|
||||
connectToSeeds(s, opts.Seeds)
|
||||
}
|
||||
|
||||
s.loop()
|
||||
}
|
||||
|
||||
// Stop the server, attemping a gracefull shutdown.
|
||||
func (s *Server) Stop() { s.quit <- struct{}{} }
|
||||
|
||||
// shutdown the server, disconnecting all peers.
|
||||
func (s *Server) shutdown() {
|
||||
s.logger.Println("attemping a quitefull shutdown.")
|
||||
s.listener.Close()
|
||||
|
||||
// disconnect and remove all connected peers.
|
||||
for peer := range s.peers {
|
||||
peer.disconnect()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) loop() {
|
||||
for {
|
||||
select {
|
||||
// When a new connection is been established, (by this server or remote node)
|
||||
// its peer will be received on this channel.
|
||||
// Any peer registration must happen via this channel.
|
||||
case peer := <-s.register:
|
||||
if len(s.peers) < maxPeers {
|
||||
s.logger.Printf("peer registered from address %s", peer.addr())
|
||||
s.peers[peer] = true
|
||||
|
||||
if err := s.handlePeerConnected(peer); err != nil {
|
||||
s.logger.Printf("failed handling peer connection: %s", err)
|
||||
peer.disconnect()
|
||||
}
|
||||
}
|
||||
|
||||
// unregister safely deletes a peer. For disconnecting peers use the
|
||||
// disconnect() method on the peer, it will call unregister and terminates its routines.
|
||||
case peer := <-s.unregister:
|
||||
if _, ok := s.peers[peer]; ok {
|
||||
delete(s.peers, peer)
|
||||
s.logger.Printf("peer %s disconnected", peer.addr())
|
||||
}
|
||||
|
||||
case t := <-s.peerCountCh:
|
||||
t.count <- len(s.peers)
|
||||
|
||||
case <-s.quit:
|
||||
s.shutdown()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// When a new peer is connected we send our version.
|
||||
// No further communication should be made before both sides has received
|
||||
// the versions of eachother.
|
||||
func (s *Server) handlePeerConnected(p Peer) error {
|
||||
// TODO: get the blockheight of this server once core implemented this.
|
||||
payload := payload.NewVersion(s.id, s.port, s.userAgent, s.bc.HeaderHeight(), s.relay)
|
||||
msg := newMessage(s.net, cmdVersion, payload)
|
||||
return p.Send(msg)
|
||||
}
|
||||
|
||||
func (s *Server) handleVersionCmd(version *payload.Version, p Peer) error {
|
||||
if s.id == version.Nonce {
|
||||
return errors.New("identical nonce")
|
||||
}
|
||||
if p.addr().Port != version.Port {
|
||||
return fmt.Errorf("port mismatch: %d and %d", version.Port, p.addr().Port)
|
||||
}
|
||||
|
||||
return p.Send(
|
||||
newMessage(s.net, cmdVerack, nil),
|
||||
)
|
||||
}
|
||||
|
||||
func (s *Server) handleGetaddrCmd(msg *Message, p Peer) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// The node can broadcast the object information it owns by this message.
|
||||
// The message can be sent automatically or can be used to answer getbloks messages.
|
||||
func (s *Server) handleInvCmd(inv *payload.Inventory, p Peer) error {
|
||||
if !inv.Type.Valid() {
|
||||
return fmt.Errorf("invalid inventory type %s", inv.Type)
|
||||
}
|
||||
if len(inv.Hashes) == 0 {
|
||||
return errors.New("inventory should have at least 1 hash got 0")
|
||||
}
|
||||
|
||||
// todo: only grab the hashes that we dont know.
|
||||
|
||||
payload := payload.NewInventory(inv.Type, inv.Hashes)
|
||||
resp := newMessage(s.net, cmdGetData, payload)
|
||||
|
||||
return p.Send(resp)
|
||||
}
|
||||
|
||||
// handleBlockCmd processes the received block.
|
||||
func (s *Server) handleBlockCmd(block *core.Block, p Peer) error {
|
||||
hash, err := block.Hash()
|
||||
func (s *Server) createListener() error {
|
||||
ln, err := net.Listen("tcp", fmt.Sprintf(":%d", s.ListenTCP))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.logger.Printf("new block: index %d hash %s", block.Index, hash)
|
||||
|
||||
s.listener = ln
|
||||
return nil
|
||||
}
|
||||
|
||||
// After receiving the getaddr message, the node returns an addr message as response
|
||||
// and provides information about the known nodes on the network.
|
||||
func (s *Server) handleAddrCmd(addrList *payload.AddressList, p Peer) error {
|
||||
for _, addr := range addrList.Addrs {
|
||||
if !s.peerAlreadyConnected(addr.Addr) {
|
||||
// TODO: this is not transport abstracted.
|
||||
go connectToRemoteNode(s, addr.Addr.String())
|
||||
func (s *Server) listenTCP() {
|
||||
for {
|
||||
conn, err := s.listener.Accept()
|
||||
if err != nil {
|
||||
s.logger.Log("msg", "conn read error", "err", err)
|
||||
break
|
||||
}
|
||||
go s.setupConnection(conn)
|
||||
}
|
||||
s.Quit()
|
||||
}
|
||||
|
||||
func (s *Server) setupConnection(conn net.Conn) {
|
||||
if !s.hasCapacity() {
|
||||
s.logger.Log("msg", "server reached maximum capacity")
|
||||
return
|
||||
}
|
||||
|
||||
p := NewTCPPeer(conn, s.proto.handleProto)
|
||||
s.register <- p
|
||||
if err := p.run(); err != nil {
|
||||
s.unregister <- p
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) connectToPeers(addrs ...string) {
|
||||
for _, addr := range addrs {
|
||||
if s.hasCapacity() && s.canConnectWith(addr) {
|
||||
go func(addr string) {
|
||||
conn, err := net.DialTimeout("tcp", addr, s.DialTimeout)
|
||||
if err != nil {
|
||||
s.badAddrOp <- func(badAddrs map[string]bool) {
|
||||
badAddrs[addr] = true
|
||||
}
|
||||
<-s.badAddrOpDone
|
||||
return
|
||||
}
|
||||
go s.setupConnection(conn)
|
||||
}(addr)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Handle the headers received from the remote after we asked for headers with the
|
||||
// "getheaders" message.
|
||||
func (s *Server) handleHeadersCmd(headers *payload.Headers, p Peer) error {
|
||||
// Set a deadline for adding headers?
|
||||
go func(ctx context.Context, headers []*core.Header) {
|
||||
if err := s.bc.AddHeaders(headers...); err != nil {
|
||||
s.logger.Printf("failed to add headers: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Ask more headers if we are not in sync with the peer.
|
||||
if s.bc.HeaderHeight() < p.version().StartHeight {
|
||||
if err := s.askMoreHeaders(p); err != nil {
|
||||
s.logger.Printf("getheaders RPC failed: %s", err)
|
||||
return
|
||||
func (s *Server) canConnectWith(addr string) bool {
|
||||
canConnect := true
|
||||
s.peerOp <- func(peers map[Peer]bool) {
|
||||
for peer := range peers {
|
||||
if peer.Endpoint().String() == addr {
|
||||
canConnect = false
|
||||
break
|
||||
}
|
||||
}
|
||||
}(context.TODO(), headers.Hdrs)
|
||||
}
|
||||
<-s.peerOpDone
|
||||
if !canConnect {
|
||||
return false
|
||||
}
|
||||
|
||||
return nil
|
||||
s.badAddrOp <- func(badAddrs map[string]bool) {
|
||||
_, ok := badAddrs[addr]
|
||||
canConnect = !ok
|
||||
}
|
||||
<-s.badAddrOpDone
|
||||
return canConnect
|
||||
}
|
||||
|
||||
// Ask the peer for more headers We use the current block hash as start.
|
||||
func (s *Server) askMoreHeaders(p Peer) error {
|
||||
start := []util.Uint256{s.bc.CurrentHeaderHash()}
|
||||
payload := payload.NewGetBlocks(start, util.Uint256{})
|
||||
msg := newMessage(s.net, cmdGetHeaders, payload)
|
||||
|
||||
return p.Send(msg)
|
||||
func (s *Server) hasCapacity() bool {
|
||||
return s.PeerCount() != s.MaxPeers
|
||||
}
|
||||
|
||||
// check if the addr is already connected to the server.
|
||||
func (s *Server) peerAlreadyConnected(addr net.Addr) bool {
|
||||
// TODO: Dont try to connect with ourselfs.
|
||||
for peer := range s.peers {
|
||||
if peer.addr().String() == addr.String() {
|
||||
return true
|
||||
func (s *Server) sendVersion(peer Peer) {
|
||||
peer.Send(NewMessage(s.Net, CMDVersion, s.proto.version()))
|
||||
}
|
||||
|
||||
func (s *Server) run() {
|
||||
var (
|
||||
ticker = time.NewTicker(30 * time.Second).C
|
||||
peers = make(map[Peer]bool)
|
||||
badAddrs = make(map[string]bool)
|
||||
)
|
||||
|
||||
for {
|
||||
select {
|
||||
case op := <-s.badAddrOp:
|
||||
op(badAddrs)
|
||||
s.badAddrOpDone <- struct{}{}
|
||||
case op := <-s.peerOp:
|
||||
op(peers)
|
||||
s.peerOpDone <- struct{}{}
|
||||
case p := <-s.register:
|
||||
peers[p] = true
|
||||
// When a new peer connection is established, we send
|
||||
// out our version immediately.
|
||||
s.sendVersion(p)
|
||||
s.logger.Log("event", "peer connected", "endpoint", p.Endpoint())
|
||||
case p := <-s.unregister:
|
||||
delete(peers, p)
|
||||
s.logger.Log("event", "peer disconnected", "endpoint", p.Endpoint())
|
||||
case <-ticker:
|
||||
s.printState()
|
||||
case <-s.quit:
|
||||
return
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// TODO: Quit this routine if the peer is disconnected.
|
||||
func (s *Server) startProtocol(p Peer) {
|
||||
if s.bc.HeaderHeight() < p.version().StartHeight {
|
||||
s.askMoreHeaders(p)
|
||||
}
|
||||
for {
|
||||
getaddrMsg := newMessage(s.net, cmdGetAddr, nil)
|
||||
p.Send(getaddrMsg)
|
||||
|
||||
time.Sleep(30 * time.Second)
|
||||
// PeerCount returns the number of current connected peers.
|
||||
func (s *Server) PeerCount() (n int) {
|
||||
s.peerOp <- func(peers map[Peer]bool) {
|
||||
n = len(peers)
|
||||
}
|
||||
<-s.peerOpDone
|
||||
return
|
||||
}
|
||||
|
||||
type peerCount struct {
|
||||
count chan int
|
||||
}
|
||||
func (s *Server) Start() error {
|
||||
fmt.Println(logo())
|
||||
fmt.Println("")
|
||||
s.printConfiguration()
|
||||
|
||||
// peerCount returns the number of connected peers to this server.
|
||||
func (s *Server) peerCount() int {
|
||||
ch := peerCount{
|
||||
count: make(chan int),
|
||||
if err := s.createListener(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.peerCountCh <- ch
|
||||
|
||||
return <-ch.count
|
||||
go s.run()
|
||||
go s.listenTCP()
|
||||
go s.connectToPeers(s.Seeds...)
|
||||
select {}
|
||||
}
|
||||
|
||||
// StartOpts holds the server configuration.
|
||||
type StartOpts struct {
|
||||
// tcp port
|
||||
TCP int
|
||||
// slice of peer addresses the server will connect to
|
||||
Seeds []string
|
||||
// JSON-RPC port. If 0 no RPC handler will be attached.
|
||||
RPC int
|
||||
func (s *Server) Quit() {
|
||||
s.quit <- struct{}{}
|
||||
}
|
||||
|
||||
func (s *Server) printState() {
|
||||
w := tabwriter.NewWriter(os.Stdout, 0, 0, 4, ' ', 0)
|
||||
fmt.Fprintf(w, "connected peers:\t%d/%d\n", s.PeerCount(), s.MaxPeers)
|
||||
w.Flush()
|
||||
}
|
||||
|
||||
func (s *Server) printConfiguration() {
|
||||
w := tabwriter.NewWriter(os.Stdout, 0, 0, 4, ' ', 0)
|
||||
fmt.Fprintf(w, "user agent:\t%s\n", s.UserAgent)
|
||||
fmt.Fprintf(w, "id:\t%d\n", s.id)
|
||||
fmt.Fprintf(w, "network:\t%s\n", s.Net)
|
||||
fmt.Fprintf(w, "listen TCP:\t%d\n", s.ListenTCP)
|
||||
fmt.Fprintf(w, "listen RPC:\t%d\n", s.ListenRPC)
|
||||
fmt.Fprintf(w, "relay:\t%v\n", s.Relay)
|
||||
fmt.Fprintf(w, "max peers:\t%d\n", s.MaxPeers)
|
||||
chainer := s.proto.(Noder)
|
||||
fmt.Fprintf(w, "current height:\t%d\n", chainer.blockchain().HeaderHeight())
|
||||
fmt.Fprintln(w, "")
|
||||
w.Flush()
|
||||
}
|
||||
|
||||
func logo() string {
|
||||
|
|
|
@ -1,60 +1,86 @@
|
|||
package network
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/CityOfZion/neo-go/pkg/network/payload"
|
||||
"github.com/CityOfZion/neo-go/pkg/util"
|
||||
log "github.com/go-kit/kit/log"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// TODO this should be moved to localPeer test.
|
||||
func TestRegisterPeer(t *testing.T) {
|
||||
s := newTestServer()
|
||||
go s.run()
|
||||
|
||||
func TestHandleVersionFailWrongPort(t *testing.T) {
|
||||
s := NewServer(ModePrivNet)
|
||||
go s.loop()
|
||||
assert.NotZero(t, s.id)
|
||||
assert.Zero(t, s.PeerCount())
|
||||
|
||||
p := NewLocalPeer(s)
|
||||
lenPeers := 10
|
||||
for i := 0; i < lenPeers; i++ {
|
||||
s.register <- newTestPeer()
|
||||
}
|
||||
assert.Equal(t, lenPeers, s.PeerCount())
|
||||
}
|
||||
|
||||
version := payload.NewVersion(1337, 1, "/NEO:0.0.0/", 0, true)
|
||||
if err := s.handleVersionCmd(version, p); err == nil {
|
||||
t.Fatal("expected error got nil")
|
||||
func TestUnregisterPeer(t *testing.T) {
|
||||
s := newTestServer()
|
||||
go s.run()
|
||||
|
||||
peer := newTestPeer()
|
||||
s.register <- peer
|
||||
s.register <- newTestPeer()
|
||||
s.register <- newTestPeer()
|
||||
assert.Equal(t, 3, s.PeerCount())
|
||||
|
||||
s.unregister <- peer
|
||||
assert.Equal(t, 2, s.PeerCount())
|
||||
}
|
||||
|
||||
type testNode struct{}
|
||||
|
||||
func (t testNode) version() *payload.Version {
|
||||
return &payload.Version{}
|
||||
}
|
||||
|
||||
func (t testNode) handleProto(msg *Message, p Peer) {}
|
||||
|
||||
func newTestServer() *Server {
|
||||
return &Server{
|
||||
logger: log.NewLogfmtLogger(os.Stderr),
|
||||
id: util.RandUint32(1000000, 9999999),
|
||||
quit: make(chan struct{}, 1),
|
||||
register: make(chan Peer),
|
||||
unregister: make(chan Peer),
|
||||
badAddrOp: make(chan func(map[string]bool)),
|
||||
badAddrOpDone: make(chan struct{}),
|
||||
peerOp: make(chan func(map[Peer]bool)),
|
||||
peerOpDone: make(chan struct{}),
|
||||
proto: testNode{},
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleVersionFailIdenticalNonce(t *testing.T) {
|
||||
s := NewServer(ModePrivNet)
|
||||
go s.loop()
|
||||
type testPeer struct {
|
||||
done chan struct{}
|
||||
}
|
||||
|
||||
p := NewLocalPeer(s)
|
||||
|
||||
version := payload.NewVersion(s.id, 1, "/NEO:0.0.0/", 0, true)
|
||||
if err := s.handleVersionCmd(version, p); err == nil {
|
||||
t.Fatal("expected error got nil")
|
||||
func newTestPeer() testPeer {
|
||||
return testPeer{
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleVersion(t *testing.T) {
|
||||
s := NewServer(ModePrivNet)
|
||||
go s.loop()
|
||||
|
||||
p := NewLocalPeer(s)
|
||||
|
||||
version := payload.NewVersion(1337, p.addr().Port, "/NEO:0.0.0/", 0, true)
|
||||
if err := s.handleVersionCmd(version, p); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
func (p testPeer) Version() *payload.Version {
|
||||
return &payload.Version{}
|
||||
}
|
||||
|
||||
func TestHandleAddrCmd(t *testing.T) {
|
||||
// todo
|
||||
func (p testPeer) Endpoint() util.Endpoint {
|
||||
return util.Endpoint{}
|
||||
}
|
||||
|
||||
func TestHandleGetAddrCmd(t *testing.T) {
|
||||
// todo
|
||||
}
|
||||
func (p testPeer) Send(msg *Message) {}
|
||||
|
||||
func TestHandleInv(t *testing.T) {
|
||||
// todo
|
||||
}
|
||||
func TestHandleBlockCmd(t *testing.T) {
|
||||
// todo
|
||||
func (p testPeer) Done() chan struct{} {
|
||||
return p.done
|
||||
}
|
||||
|
|
|
@ -1,254 +0,0 @@
|
|||
package network
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/CityOfZion/neo-go/pkg/core"
|
||||
"github.com/CityOfZion/neo-go/pkg/network/payload"
|
||||
"github.com/CityOfZion/neo-go/pkg/util"
|
||||
)
|
||||
|
||||
func listenTCP(s *Server, port int) error {
|
||||
ln, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.listener = ln
|
||||
|
||||
for {
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
go handleConnection(s, conn)
|
||||
}
|
||||
}
|
||||
|
||||
func connectToRemoteNode(s *Server, address string) {
|
||||
conn, err := net.Dial("tcp", address)
|
||||
if err != nil {
|
||||
s.logger.Printf("failed to connect to remote node %s", address)
|
||||
if conn != nil {
|
||||
conn.Close()
|
||||
}
|
||||
return
|
||||
}
|
||||
go handleConnection(s, conn)
|
||||
}
|
||||
|
||||
func connectToSeeds(s *Server, addrs []string) {
|
||||
for _, addr := range addrs {
|
||||
go connectToRemoteNode(s, addr)
|
||||
}
|
||||
}
|
||||
|
||||
func handleConnection(s *Server, conn net.Conn) {
|
||||
peer := NewTCPPeer(conn, s)
|
||||
s.register <- peer
|
||||
|
||||
// remove the peer from connected peers and cleanup the connection.
|
||||
defer func() {
|
||||
peer.disconnect()
|
||||
}()
|
||||
|
||||
// Start a goroutine that will handle all outgoing messages.
|
||||
go peer.writeLoop()
|
||||
// Start a goroutine that will handle all incomming messages.
|
||||
go handleMessage(s, peer)
|
||||
|
||||
// Read from the connection and decode it into a Message ready for processing.
|
||||
for {
|
||||
msg := &Message{}
|
||||
if err := msg.decode(conn); err != nil {
|
||||
s.logger.Printf("decode error: %s", err)
|
||||
break
|
||||
}
|
||||
|
||||
peer.receive <- msg
|
||||
}
|
||||
}
|
||||
|
||||
// handleMessage multiplexes the message received from a TCP connection to a server command.
|
||||
func handleMessage(s *Server, p *TCPPeer) {
|
||||
var err error
|
||||
|
||||
for {
|
||||
msg := <-p.receive
|
||||
command := msg.commandType()
|
||||
|
||||
// s.logger.Printf("IN :: %d :: %s :: %v", p.id(), command, msg)
|
||||
|
||||
switch command {
|
||||
case cmdVersion:
|
||||
version := msg.Payload.(*payload.Version)
|
||||
if err = s.handleVersionCmd(version, p); err != nil {
|
||||
break
|
||||
}
|
||||
p.nonce = version.Nonce
|
||||
p.pVersion = version
|
||||
|
||||
// When a node receives a connection request, it declares its version immediately.
|
||||
// There will be no other communication until both sides are getting versions of each other.
|
||||
// When a node receives the version message, it replies to a verack as a response immediately.
|
||||
// NOTE: The current official NEO nodes dont mimic this behaviour. There is small chance that the
|
||||
// official nodes will not respond directly with a verack after we sended our version.
|
||||
// is this a bug? - anthdm 02/02/2018
|
||||
msgVerack := <-p.receive
|
||||
if msgVerack.commandType() != cmdVerack {
|
||||
err = errors.New("expected verack after sended out version")
|
||||
break
|
||||
}
|
||||
|
||||
// start the protocol
|
||||
go s.startProtocol(p)
|
||||
case cmdAddr:
|
||||
addrList := msg.Payload.(*payload.AddressList)
|
||||
err = s.handleAddrCmd(addrList, p)
|
||||
case cmdGetAddr:
|
||||
err = s.handleGetaddrCmd(msg, p)
|
||||
case cmdInv:
|
||||
inv := msg.Payload.(*payload.Inventory)
|
||||
err = s.handleInvCmd(inv, p)
|
||||
case cmdBlock:
|
||||
block := msg.Payload.(*core.Block)
|
||||
err = s.handleBlockCmd(block, p)
|
||||
case cmdConsensus:
|
||||
case cmdTX:
|
||||
case cmdVerack:
|
||||
// If we receive a verack here we disconnect. We already handled the verack
|
||||
// when we sended our version.
|
||||
err = errors.New("verack already received")
|
||||
case cmdGetHeaders:
|
||||
case cmdGetBlocks:
|
||||
case cmdGetData:
|
||||
case cmdHeaders:
|
||||
headers := msg.Payload.(*payload.Headers)
|
||||
err = s.handleHeadersCmd(headers, p)
|
||||
default:
|
||||
// This command is unknown by the server.
|
||||
err = fmt.Errorf("unknown command received %v", msg.Command)
|
||||
break
|
||||
}
|
||||
|
||||
// catch all errors here and disconnect.
|
||||
if err != nil {
|
||||
s.logger.Printf("processing message failed: %s", err)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Disconnect the peer when breaked out of the loop.
|
||||
p.disconnect()
|
||||
}
|
||||
|
||||
type sendTuple struct {
|
||||
msg *Message
|
||||
err chan error
|
||||
}
|
||||
|
||||
// TCPPeer represents a remote node, backed by TCP transport.
|
||||
type TCPPeer struct {
|
||||
s *Server
|
||||
// nonce (id) of the peer.
|
||||
nonce uint32
|
||||
// underlying TCP connection
|
||||
conn net.Conn
|
||||
// host and port information about this peer.
|
||||
endpoint util.Endpoint
|
||||
// channel to coordinate messages writen back to the connection.
|
||||
send chan sendTuple
|
||||
// channel to receive from underlying connection.
|
||||
receive chan *Message
|
||||
// the version sended out by the peer when connected.
|
||||
pVersion *payload.Version
|
||||
}
|
||||
|
||||
// NewTCPPeer returns a pointer to a TCP Peer.
|
||||
func NewTCPPeer(conn net.Conn, s *Server) *TCPPeer {
|
||||
e, _ := util.EndpointFromString(conn.RemoteAddr().String())
|
||||
|
||||
return &TCPPeer{
|
||||
conn: conn,
|
||||
send: make(chan sendTuple),
|
||||
receive: make(chan *Message),
|
||||
endpoint: e,
|
||||
s: s,
|
||||
}
|
||||
}
|
||||
|
||||
// Send needed to implement the network.Peer interface
|
||||
// and provide the functionality to send a message to
|
||||
// the current peer.
|
||||
func (p *TCPPeer) Send(msg *Message) error {
|
||||
t := sendTuple{
|
||||
msg: msg,
|
||||
err: make(chan error),
|
||||
}
|
||||
|
||||
p.send <- t
|
||||
|
||||
return <-t.err
|
||||
}
|
||||
|
||||
func (p *TCPPeer) version() *payload.Version {
|
||||
return p.pVersion
|
||||
}
|
||||
|
||||
// id implements the peer interface
|
||||
func (p *TCPPeer) id() uint32 {
|
||||
return p.nonce
|
||||
}
|
||||
|
||||
// endpoint implements the peer interface
|
||||
func (p *TCPPeer) addr() util.Endpoint {
|
||||
return p.endpoint
|
||||
}
|
||||
|
||||
// disconnect disconnects the peer, cleaning up all its resources.
|
||||
// 3 goroutines needs to be cleanup (writeLoop, handleConnection and handleMessage)
|
||||
func (p *TCPPeer) disconnect() {
|
||||
select {
|
||||
case <-p.send:
|
||||
case <-p.receive:
|
||||
default:
|
||||
close(p.send)
|
||||
close(p.receive)
|
||||
p.s.unregister <- p
|
||||
p.conn.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// writeLoop writes messages to the underlying TCP connection.
|
||||
// A goroutine writeLoop is started for each connection.
|
||||
// There should be at most one writer to a connection executing
|
||||
// all writes from this goroutine.
|
||||
func (p *TCPPeer) writeLoop() {
|
||||
// clean up the connection.
|
||||
defer func() {
|
||||
p.disconnect()
|
||||
}()
|
||||
|
||||
// resuse this buffer
|
||||
buf := new(bytes.Buffer)
|
||||
for {
|
||||
t := <-p.send
|
||||
if t.msg == nil {
|
||||
break // send probably closed.
|
||||
}
|
||||
|
||||
// p.s.logger.Printf("OUT :: %s :: %+v", t.msg.commandType(), t.msg.Payload)
|
||||
|
||||
if err := t.msg.encode(buf); err != nil {
|
||||
t.err <- err
|
||||
}
|
||||
_, err := p.conn.Write(buf.Bytes())
|
||||
t.err <- err
|
||||
|
||||
buf.Reset()
|
||||
}
|
||||
}
|
134
pkg/network/tcp_peer.go
Normal file
134
pkg/network/tcp_peer.go
Normal file
|
@ -0,0 +1,134 @@
|
|||
package network
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/CityOfZion/neo-go/pkg/network/payload"
|
||||
"github.com/CityOfZion/neo-go/pkg/util"
|
||||
log "github.com/go-kit/kit/log"
|
||||
)
|
||||
|
||||
// TCPPeer represents a connected remote node in the
|
||||
// network over TCP.
|
||||
type TCPPeer struct {
|
||||
// The endpoint of the peer.
|
||||
endpoint util.Endpoint
|
||||
|
||||
// underlying connection.
|
||||
conn net.Conn
|
||||
|
||||
// The version the peer declared when connecting.
|
||||
version *payload.Version
|
||||
|
||||
// connectedAt is the timestamp the peer connected to
|
||||
// the network.
|
||||
connectedAt time.Time
|
||||
|
||||
// handleProto is the handler that will handle the
|
||||
// incoming message along with its peer.
|
||||
handleProto protoHandleFunc
|
||||
|
||||
// Done is used to broadcast this peer has stopped running
|
||||
// and should be removed as reference.
|
||||
done chan struct{}
|
||||
send chan *Message
|
||||
|
||||
logger log.Logger
|
||||
}
|
||||
|
||||
// NewTCPPeer creates a new peer from a TCP connection.
|
||||
func NewTCPPeer(conn net.Conn, fun protoHandleFunc) *TCPPeer {
|
||||
e := util.NewEndpoint(conn.RemoteAddr().String())
|
||||
logger := log.NewLogfmtLogger(os.Stderr)
|
||||
logger = log.With(logger, "component", "peer", "endpoint", e)
|
||||
|
||||
return &TCPPeer{
|
||||
endpoint: e,
|
||||
conn: conn,
|
||||
done: make(chan struct{}),
|
||||
send: make(chan *Message),
|
||||
logger: logger,
|
||||
connectedAt: time.Now().UTC(),
|
||||
handleProto: fun,
|
||||
}
|
||||
}
|
||||
|
||||
// Version implements the Peer interface.
|
||||
func (p *TCPPeer) Version() *payload.Version {
|
||||
return p.version
|
||||
}
|
||||
|
||||
// Endpoint implements the Peer interface.
|
||||
func (p *TCPPeer) Endpoint() util.Endpoint {
|
||||
return p.endpoint
|
||||
}
|
||||
|
||||
// Send implements the Peer interface.
|
||||
func (p *TCPPeer) Send(msg *Message) {
|
||||
p.send <- msg
|
||||
}
|
||||
|
||||
// Done implemnets the Peer interface.
|
||||
func (p *TCPPeer) Done() chan struct{} {
|
||||
return p.done
|
||||
}
|
||||
|
||||
func (p *TCPPeer) run() error {
|
||||
errCh := make(chan error, 1)
|
||||
|
||||
go p.readLoop(errCh)
|
||||
go p.writeLoop(errCh)
|
||||
|
||||
err := <-errCh
|
||||
p.logger.Log("err", err)
|
||||
p.cleanup()
|
||||
return err
|
||||
}
|
||||
|
||||
func (p *TCPPeer) readLoop(errCh chan error) {
|
||||
for {
|
||||
msg := &Message{}
|
||||
if err := msg.decode(p.conn); err != nil {
|
||||
errCh <- err
|
||||
break
|
||||
}
|
||||
p.handleMessage(msg)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *TCPPeer) writeLoop(errCh chan error) {
|
||||
buf := new(bytes.Buffer)
|
||||
|
||||
for {
|
||||
msg := <-p.send
|
||||
if err := msg.encode(buf); err != nil {
|
||||
errCh <- err
|
||||
break
|
||||
}
|
||||
if _, err := p.conn.Write(buf.Bytes()); err != nil {
|
||||
errCh <- err
|
||||
break
|
||||
}
|
||||
buf.Reset()
|
||||
}
|
||||
}
|
||||
|
||||
func (p *TCPPeer) cleanup() {
|
||||
p.conn.Close()
|
||||
close(p.send)
|
||||
p.done <- struct{}{}
|
||||
}
|
||||
|
||||
func (p *TCPPeer) handleMessage(msg *Message) {
|
||||
switch msg.CommandType() {
|
||||
case CMDVersion:
|
||||
version := msg.Payload.(*payload.Version)
|
||||
p.version = version
|
||||
p.handleProto(msg, p)
|
||||
default:
|
||||
p.handleProto(msg, p)
|
||||
}
|
||||
}
|
|
@ -12,14 +12,11 @@ type Endpoint struct {
|
|||
Port uint16
|
||||
}
|
||||
|
||||
// EndpointFromString returns an Endpoint from the given string.
|
||||
// For now this only handles the most simple hostport form.
|
||||
// e.g. 127.0.0.1:3000
|
||||
// This should be enough to work with for now.
|
||||
func EndpointFromString(s string) (Endpoint, error) {
|
||||
// NewEndpoint creates an Endpoint from the given string.
|
||||
func NewEndpoint(s string) (e Endpoint) {
|
||||
hostPort := strings.Split(s, ":")
|
||||
if len(hostPort) != 2 {
|
||||
return Endpoint{}, fmt.Errorf("invalid address string: %s", s)
|
||||
return e
|
||||
}
|
||||
host := hostPort[0]
|
||||
port := hostPort[1]
|
||||
|
@ -36,7 +33,7 @@ func EndpointFromString(s string) (Endpoint, error) {
|
|||
|
||||
p, _ := strconv.Atoi(port)
|
||||
|
||||
return Endpoint{buf, uint16(p)}, nil
|
||||
return Endpoint{buf, uint16(p)}
|
||||
}
|
||||
|
||||
// Network implements the net.Addr interface.
|
||||
|
|
|
@ -34,7 +34,7 @@ func Uint256DecodeBytes(b []byte) (u Uint256, err error) {
|
|||
return u, nil
|
||||
}
|
||||
|
||||
// ToSlice returns a byte slice representation of u.
|
||||
// Bytes returns a byte slice representation of u.
|
||||
func (u Uint256) Bytes() []byte {
|
||||
b := make([]byte, uint256Size)
|
||||
for i := 0; i < uint256Size; i++ {
|
||||
|
|
|
@ -8,7 +8,7 @@ import (
|
|||
)
|
||||
|
||||
const (
|
||||
// The WIF network version used to decode and encode WIF keys.
|
||||
// WIFVersion is the version used to decode and encode WIF keys.
|
||||
WIFVersion = 0x80
|
||||
)
|
||||
|
||||
|
|
Loading…
Reference in a new issue