Node improvements (#47)

* block partial persist

* replaced refactored files with old one.

* removed gokit/log from deps

* Tweaks to not overburden remote nodes with getheaders/getblocks

* Changed Transporter interface to not take the server as argument due to a cause of race warning from the compiler

* started server test suite

* more test + return errors from message handlers

* removed --race from build

* Little improvements.
This commit is contained in:
Anthony De Meulemeester 2018-03-14 10:36:59 +01:00 committed by GitHub
parent dca1865a64
commit aa4bc1b6e8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
42 changed files with 1187 additions and 892 deletions

0
View file

7
.gitignore vendored
View file

@ -15,7 +15,6 @@ vendor/
bin/
# text editors
# vscode
.vscode/*
!.vscode/settings.json
@ -25,3 +24,9 @@ bin/
# anthdm todolists
/pkg/vm/compiler/todo.md
# leveldb
chains/
chain/
blockchain/
blockchains/

20
Gopkg.lock generated
View file

@ -55,6 +55,12 @@
revision = "792786c7400a136282c1664665ae0a8db921c6c2"
version = "v1.0.0"
[[projects]]
name = "github.com/sirupsen/logrus"
packages = ["."]
revision = "c155da19408a8799da419ed3eeb0cb5db0ad5dbc"
version = "v1.0.5"
[[projects]]
name = "github.com/stretchr/testify"
packages = ["assert"]
@ -92,10 +98,20 @@
packages = [
"pbkdf2",
"ripemd160",
"scrypt"
"scrypt",
"ssh/terminal"
]
revision = "8c653846df49742c4c85ec37e5d9f8d3ba657895"
[[projects]]
branch = "master"
name = "golang.org/x/sys"
packages = [
"unix",
"windows"
]
revision = "c28acc882ebcbfbe8ce9f0f14b9ac26ee138dd51"
[[projects]]
name = "golang.org/x/text"
packages = [
@ -122,6 +138,6 @@
[solve-meta]
analyzer-name = "dep"
analyzer-version = 1
inputs-digest = "53597073e919ad7bf52895a19f8b8526d12d666862fb1d36b4a9756e0499da5a"
inputs-digest = "333dfa54a358d83b266025eff7f15854652631d90e18e61fa75723c5a030778b"
solver-name = "gps-cdcl"
solver-version = 1

View file

@ -53,5 +53,5 @@
version = "1.2.1"
[[constraint]]
name = "github.com/go-kit/kit"
version = "0.6.0"
name = "github.com/sirupsen/logrus"
version = "1.0.5"

View file

@ -2,6 +2,7 @@ BRANCH = "master"
VERSION = $(shell cat ./VERSION)
SEEDS ?= "127.0.0.1:20333"
PORT ?= "3000"
DBFILE ?= "chain"
build:
@go build -o ./bin/neo-go ./cli/main.go
@ -19,7 +20,7 @@ push-tag:
git push origin ${BRANCH} --tags
run: build
./bin/neo-go node -seed ${SEEDS} -tcp ${PORT} --relay true
./bin/neo-go node -seed ${SEEDS} -tcp ${PORT} -dbfile ${DBFILE} --relay true
test:
@go test ./... -cover

View file

@ -1 +1 @@
0.28.0
0.29.0

View file

@ -1,9 +1,12 @@
package server
import (
"fmt"
"strings"
"github.com/CityOfZion/neo-go/pkg/core"
"github.com/CityOfZion/neo-go/pkg/network"
"github.com/CityOfZion/neo-go/pkg/util"
"github.com/urfave/cli"
)
@ -18,6 +21,7 @@ func NewCommand() cli.Command {
cli.IntFlag{Name: "rpc"},
cli.BoolFlag{Name: "relay, r"},
cli.StringFlag{Name: "seed"},
cli.StringFlag{Name: "dbfile"},
cli.BoolFlag{Name: "privnet, p"},
cli.BoolFlag{Name: "mainnet, m"},
cli.BoolFlag{Name: "testnet, t"},
@ -42,11 +46,41 @@ func startServer(ctx *cli.Context) error {
Relay: ctx.Bool("relay"),
}
s := network.NewServer(cfg)
chain, err := newBlockchain(net, ctx.String("dbfile"))
if err != nil {
err = fmt.Errorf("could not initialize blockhain: %s", err)
return cli.NewExitError(err, 1)
}
s := network.NewServer(cfg, chain)
s.Start()
return nil
}
func newBlockchain(net network.NetMode, path string) (*core.Blockchain, error) {
var startHash util.Uint256
if net == network.ModePrivNet {
startHash = core.GenesisHashPrivNet()
}
if net == network.ModeTestNet {
startHash = core.GenesisHashTestNet()
}
if net == network.ModeMainNet {
startHash = core.GenesisHashMainNet()
}
// Hardcoded for now.
store, err := core.NewLevelDBStore(path, nil)
if err != nil {
return nil, err
}
return core.NewBlockchain(
store,
startHash,
), nil
}
func parseSeeds(s string) []string {
if len(s) == 0 {
return nil

View file

@ -4,12 +4,12 @@ import (
"bytes"
"encoding/binary"
"fmt"
"os"
"sync/atomic"
"time"
"github.com/CityOfZion/neo-go/pkg/util"
log "github.com/go-kit/kit/log"
log "github.com/sirupsen/logrus"
"github.com/syndtr/goleveldb/leveldb"
)
// tuning parameters
@ -19,13 +19,12 @@ const (
)
var (
genAmount = []int{8, 7, 6, 5, 4, 3, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}
genAmount = []int{8, 7, 6, 5, 4, 3, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}
persistInterval = 5 * time.Second
)
// Blockchain holds the chain.
type Blockchain struct {
logger log.Logger
// Any object that satisfies the BlockchainStorer interface.
Store
@ -53,17 +52,13 @@ type headersOpFunc func(headerList *HeaderHashList)
// NewBlockchain creates a new Blockchain object.
func NewBlockchain(s Store, startHash util.Uint256) *Blockchain {
logger := log.NewLogfmtLogger(os.Stderr)
logger = log.With(logger, "component", "blockchain")
bc := &Blockchain{
logger: logger,
Store: s,
headersOp: make(chan headersOpFunc),
headersOpDone: make(chan struct{}),
startHash: startHash,
blockCache: NewCache(),
verifyBlocks: true,
verifyBlocks: false,
}
go bc.run()
bc.init()
@ -77,22 +72,30 @@ func (bc *Blockchain) init() {
}
func (bc *Blockchain) run() {
headerList := NewHeaderHashList(bc.startHash)
var (
headerList = NewHeaderHashList(bc.startHash)
persistTimer = time.NewTimer(persistInterval)
)
for {
select {
case op := <-bc.headersOp:
op(headerList)
bc.headersOpDone <- struct{}{}
case <-persistTimer.C:
go bc.persist()
persistTimer.Reset(persistInterval)
}
}
}
// AddBlock processes the given block and will add it to the cache so it
// can be persisted.
func (bc *Blockchain) AddBlock(block *Block) error {
if !bc.blockCache.Has(block.Hash()) {
bc.blockCache.Add(block.Hash(), block)
}
headerLen := int(bc.HeaderHeight() + 1)
headerLen := bc.headerListLen()
if int(block.Index-1) >= headerLen {
return nil
}
@ -105,10 +108,12 @@ func (bc *Blockchain) AddBlock(block *Block) error {
return nil
}
// AddHeaders will process the given headers and add them to the
// HeaderHashList.
func (bc *Blockchain) AddHeaders(headers ...*Header) (err error) {
var (
start = time.Now()
batch = Batch{}
batch = new(leveldb.Batch)
)
bc.headersOp <- func(headerList *HeaderHashList) {
@ -132,16 +137,15 @@ func (bc *Blockchain) AddHeaders(headers ...*Header) (err error) {
}
}
// TODO: Implement caching strategy.
if len(batch) > 0 {
if batch.Len() > 0 {
if err = bc.writeBatch(batch); err != nil {
return
}
bc.logger.Log(
"msg", "done processing headers",
"index", headerList.Len()-1,
"took", time.Since(start).Seconds(),
)
log.WithFields(log.Fields{
"headerIndex": headerList.Len() - 1,
"blockHeight": bc.BlockHeight(),
"took": time.Since(start),
}).Debug("done processing headers")
}
}
<-bc.headersOpDone
@ -150,7 +154,7 @@ func (bc *Blockchain) AddHeaders(headers ...*Header) (err error) {
// processHeader processes the given header. Note that this is only thread safe
// if executed in headers operation.
func (bc *Blockchain) processHeader(h *Header, batch Batch, headerList *HeaderHashList) error {
func (bc *Blockchain) processHeader(h *Header, batch *leveldb.Batch, headerList *HeaderHashList) error {
headerList.Add(h.Hash())
buf := new(bytes.Buffer)
@ -159,7 +163,7 @@ func (bc *Blockchain) processHeader(h *Header, batch Batch, headerList *HeaderHa
return err
}
key := makeEntryPrefixInt(preIXHeaderHashList, int(bc.storedHeaderCount))
batch[&key] = buf.Bytes()
batch.Put(key, buf.Bytes())
bc.storedHeaderCount += headerBatchCount
buf.Reset()
}
@ -170,29 +174,40 @@ func (bc *Blockchain) processHeader(h *Header, batch Batch, headerList *HeaderHa
}
key := makeEntryPrefix(preDataBlock, h.Hash().BytesReverse())
batch[&key] = buf.Bytes()
batch.Put(key, buf.Bytes())
key = preSYSCurrentHeader.bytes()
batch[&key] = hashAndIndexToBytes(h.Hash(), h.Index)
batch.Put(key, hashAndIndexToBytes(h.Hash(), h.Index))
return nil
}
func (bc *Blockchain) persistBlock(block *Block) error {
bc.blockHeight = block.Index
batch := new(leveldb.Batch)
// Store the block.
key := preSYSCurrentBlock.bytes()
batch.Put(key, hashAndIndexToBytes(block.Hash(), block.Index))
if err := bc.Store.writeBatch(batch); err != nil {
return err
}
atomic.AddUint32(&bc.blockHeight, 1)
return nil
}
func (bc *Blockchain) persist() (err error) {
var (
start = time.Now()
persisted = 0
lenCache = bc.blockCache.Len()
)
for lenCache > persisted {
if bc.HeaderHeight()+1 <= bc.BlockHeight() {
break
}
bc.headersOp <- func(headerList *HeaderHashList) {
bc.headersOp <- func(headerList *HeaderHashList) {
for i := 0; i < lenCache; i++ {
if uint32(headerList.Len()) <= bc.BlockHeight() {
return
}
hash := headerList.Get(int(bc.BlockHeight() + 1))
if block, ok := bc.blockCache.GetBlock(hash); ok {
if err = bc.persistBlock(block); err != nil {
@ -200,18 +215,47 @@ func (bc *Blockchain) persist() (err error) {
}
bc.blockCache.Delete(hash)
persisted++
} else {
bc.logger.Log(
"msg", "block not found in cache",
"hash", block.Hash(),
)
}
}
<-bc.headersOpDone
}
<-bc.headersOpDone
if persisted > 0 {
log.WithFields(log.Fields{
"persisted": persisted,
"blockHeight": bc.BlockHeight(),
"took": time.Since(start),
}).Info("blockchain persist completed")
}
return
}
func (bc *Blockchain) headerListLen() (n int) {
bc.headersOp <- func(headerList *HeaderHashList) {
n = headerList.Len()
}
<-bc.headersOpDone
return
}
// GetBlock returns a Block by the given hash.
func (bc *Blockchain) GetBlock(hash util.Uint256) (*Block, error) {
return nil, nil
}
// HasBlock return true if the blockchain contains he given
// transaction hash.
func (bc *Blockchain) HasTransaction(hash util.Uint256) bool {
return false
}
// HasBlock return true if the blockchain contains the given
// block hash.
func (bc *Blockchain) HasBlock(hash util.Uint256) bool {
return false
}
// CurrentBlockHash returns the heighest processed block hash.
func (bc *Blockchain) CurrentBlockHash() (hash util.Uint256) {
bc.headersOp <- func(headerList *HeaderHashList) {
@ -230,18 +274,24 @@ func (bc *Blockchain) CurrentHeaderHash() (hash util.Uint256) {
return
}
// GetHeaderHash return the hash from the headerList by its
// height/index.
func (bc *Blockchain) GetHeaderHash(i int) (hash util.Uint256) {
bc.headersOp <- func(headerList *HeaderHashList) {
hash = headerList.Get(i)
}
<-bc.headersOpDone
return
}
// BlockHeight returns the height/index of the highest block.
func (bc *Blockchain) BlockHeight() uint32 {
return atomic.LoadUint32(&bc.blockHeight)
}
// HeaderHeight returns the index/height of the highest header.
func (bc *Blockchain) HeaderHeight() (n uint32) {
bc.headersOp <- func(headerList *HeaderHashList) {
n = uint32(headerList.Len() - 1)
}
<-bc.headersOpDone
return
func (bc *Blockchain) HeaderHeight() uint32 {
return uint32(bc.headerListLen() - 1)
}
func hashAndIndexToBytes(h util.Uint256, index uint32) []byte {

View file

@ -76,6 +76,5 @@ func TestAddBlock(t *testing.T) {
func newTestBC() *Blockchain {
startHash, _ := util.Uint256DecodeString("a")
bc := NewBlockchain(NewMemoryStore(), startHash)
bc.verifyBlocks = false
return bc
}

17
pkg/core/blockchainer.go Normal file
View file

@ -0,0 +1,17 @@
package core
import "github.com/CityOfZion/neo-go/pkg/util"
// Blockchainer is an interface that abstract the implementation
// of the blockchain.
type Blockchainer interface {
AddHeaders(...*Header) error
AddBlock(*Block) error
BlockHeight() uint32
HeaderHeight() uint32
GetHeaderHash(int) util.Uint256
CurrentHeaderHash() util.Uint256
CurrentBlockHash() util.Uint256
HasBlock(util.Uint256) bool
HasTransaction(util.Uint256) bool
}

View file

@ -7,7 +7,7 @@ import (
)
// Cache is data structure with fixed type key of Uint256, but has a
// generic value. Used for block and header cash types.
// generic value. Used for block, tx and header cache types.
type Cache struct {
lock sync.RWMutex
m map[util.Uint256]interface{}

View file

@ -8,6 +8,8 @@ import (
)
// A HeaderHashList represents a list of header hashes.
// This datastructure in not routine safe and should be
// used under some kind of protection against race conditions.
type HeaderHashList struct {
hashes []util.Uint256
}
@ -31,7 +33,7 @@ func (l *HeaderHashList) Len() int {
// Get returns the hash by the given index.
func (l *HeaderHashList) Get(i int) util.Uint256 {
if l.Len() < i {
if l.Len() <= i {
return util.Uint256{}
}
return l.hashes[i]

View file

@ -2,24 +2,40 @@ package core
import (
"github.com/syndtr/goleveldb/leveldb"
"github.com/syndtr/goleveldb/leveldb/opt"
)
// LevelDBStore is the official storage implementation for storing and retreiving
// the blockchain.
// blockchain data.
type LevelDBStore struct {
db *leveldb.DB
db *leveldb.DB
path string
}
// Write implements the Store interface.
// NewLevelDBStore return a new LevelDBStore object that will
// initialize the database found at the given path.
func NewLevelDBStore(path string, opts *opt.Options) (*LevelDBStore, error) {
db, err := leveldb.OpenFile(path, opts)
if err != nil {
return nil, err
}
return &LevelDBStore{
path: path,
db: db,
}, nil
}
// write implements the Store interface.
func (s *LevelDBStore) write(key, value []byte) error {
return s.db.Put(key, value, nil)
}
// WriteBatch implements the Store interface.
func (s *LevelDBStore) writeBatch(batch Batch) error {
b := new(leveldb.Batch)
for k, v := range batch {
b.Put(*k, v)
}
return s.db.Write(b, nil)
//get implements the Store interface.
func (s *LevelDBStore) get(key []byte) ([]byte, error) {
return s.db.Get(key, nil)
}
// writeBatch implements the Store interface.
func (s *LevelDBStore) writeBatch(batch *leveldb.Batch) error {
return s.db.Write(batch, nil)
}

31
pkg/core/leveldb_test.go Normal file
View file

@ -0,0 +1,31 @@
package core
import (
"os"
"testing"
"github.com/CityOfZion/neo-go/pkg/util"
"github.com/syndtr/goleveldb/leveldb/opt"
)
const (
path = "test_chain"
)
func TestPersistBlock(t *testing.T) {
}
func newBlockchain() *Blockchain {
startHash, _ := util.Uint256DecodeString("a")
opts := &opt.Options{}
store, _ := NewLevelDBStore(path, opts)
chain := NewBlockchain(
store,
startHash,
)
return chain
}
func tearDown() error {
return os.RemoveAll(path)
}

View file

@ -1,24 +1,27 @@
package core
import "github.com/syndtr/goleveldb/leveldb"
// MemoryStore is an in memory implementation of a BlockChainStorer
// that should only be used for testing.
type MemoryStore struct {
}
type MemoryStore struct{}
// NewMemoryStore returns a pointer to a MemoryStore object.
func NewMemoryStore() *MemoryStore {
return &MemoryStore{}
}
// get implementes the BlockchainStorer interface.
func (m *MemoryStore) get(key []byte) ([]byte, error) {
return nil, nil
}
// write implementes the BlockchainStorer interface.
func (m *MemoryStore) write(key, value []byte) error {
return nil
}
func (m *MemoryStore) writeBatch(batch Batch) error {
for k, v := range batch {
if err := m.write(*k, v); err != nil {
return err
}
}
// writeBatch implementes the BlockchainStorer interface.
func (m *MemoryStore) writeBatch(batch *leveldb.Batch) error {
return nil
}

View file

@ -0,0 +1,7 @@
package core
var (
rawBlock0 = "000000000000000000000000000000000000000000000000000000000000000000000000f41bc036e39b0d6b0579c851c6fde83af802fa4e57bec0bc3365eae3abf43f8065fc8857000000001dac2b7c0000000059e75d652b5d3827bf04c165bbe9ef95cca4bf55010001510400001dac2b7c00000000400000455b7b226c616e67223a227a682d434e222c226e616d65223a22e5b08fe89a81e882a1227d2c7b226c616e67223a22656e222c226e616d65223a22416e745368617265227d5d0000c16ff28623000000da1745e9b549bd0bfa1a569971c77eba30cd5a4b00000000400001445b7b226c616e67223a227a682d434e222c226e616d65223a22e5b08fe89a81e5b881227d2c7b226c616e67223a22656e222c226e616d65223a22416e74436f696e227d5d0000c16ff286230008009f7fd096d37ed2c0e3f7f0cfc924beef4ffceb680000000001000000019b7cffdaa674beae0f930ebe6085af9093e5fe56b34a5c220ccdcf6efc336fc50000c16ff28623005fa99d93303775fe50ca119c327759313eccfa1c01000151"
rawBlock1 = "00000000bf4421c88776c53b43ce1dc45463bfd2028e322fdfb60064be150ed3e36125d418f98ec3ed2c2d1c9427385e7b85d0d1a366e29c4e399693a59718380f8bbad6d6d90358010000004490d0bb7170726c59e75d652b5d3827bf04c165bbe9ef95cca4bf5501fd4501404edf5005771de04619235d5a4c7a9a11bb78e008541f1da7725f654c33380a3c87e2959a025da706d7255cb3a3fa07ebe9c6559d0d9e6213c68049168eb1056f4038a338f879930c8adc168983f60aae6f8542365d844f004976346b70fb0dd31aa1dbd4abd81e4a4aeef9941ecd4e2dd2c1a5b05e1cc74454d0403edaee6d7a4d4099d33c0b889bf6f3e6d87ab1b11140282e9a3265b0b9b918d6020b2c62d5a040c7e0c2c7c1dae3af9b19b178c71552ebd0b596e401c175067c70ea75717c8c00404e0ebd369e81093866fe29406dbf6b402c003774541799d08bf9bb0fc6070ec0f6bad908ab95f05fa64e682b485800b3c12102a8596e6c715ec76f4564d5eff34070e0521979fcd2cbbfa1456d97cc18d9b4a6ad87a97a2a0bcdedbf71b6c9676c645886056821b6f3fec8694894c66f41b762bc4e29e46ad15aee47f05d27d822f1552102486fd15702c4490a26703112a5cc1d0923fd697a33406bd5a1c00e0013b09a7021024c7b7fb6c310fccf1ba33b082519d82964ea93868d676662d4a59ad548df0e7d2102aaec38470f6aad0042c6e877cfd8087d2676b0f516fddd362801b9bd3936399e2103b209fd4f53a7170ea4444e0cb0a6bb6a53c2bd016926989cf85f9b0fba17a70c2103b8d9d5771d8f513aa0869b9cc8d50986403b78c6da36890638c3d46a5adce04a2102ca0e27697b9c248f6f16e085fd0061e26f44da85b58ee835c110caa5ec3ba5542102df48f60e8f3e01c48ff40b9b7f1310d7a8b2a193188befe1c2e3df740e89509357ae0100004490d0bb00000000"
)

View file

@ -3,6 +3,8 @@ package core
import (
"bytes"
"encoding/binary"
"github.com/syndtr/goleveldb/leveldb"
)
type dataEntry uint8
@ -45,10 +47,7 @@ func makeEntryPrefix(e dataEntry, b []byte) []byte {
// Store is anything that can persist and retrieve the blockchain.
type Store interface {
get(k []byte) ([]byte, error)
write(k, v []byte) error
writeBatch(Batch) error
writeBatch(batch *leveldb.Batch) error
}
// Batch is a data type used to store data for later batch operations
// that can be used by any Store interface implementation.
type Batch map[*[]byte][]byte

View file

@ -10,8 +10,8 @@ const (
ECDH03 AttrUsage = 0x03
Script AttrUsage = 0x20
Vote AttrUsage = 0x30
CertUrl AttrUsage = 0x80
DescriptionUrl AttrUsage = 0x81
CertURL AttrUsage = 0x80
DescriptionURL AttrUsage = 0x81
Description AttrUsage = 0x90
Hash1 AttrUsage = 0xa1
@ -45,5 +45,5 @@ const (
Remark12 AttrUsage = 0xfc
Remark13 AttrUsage = 0xfd
Remark14 AttrUsage = 0xfe
Remark15 AttrUsage = 0xf
Remark15 AttrUsage = 0xff
)

View file

@ -34,7 +34,7 @@ func (attr *Attribute) DecodeBinary(r io.Reader) error {
attr.Data = make([]byte, 20)
return binary.Read(r, binary.LittleEndian, attr.Data)
}
if attr.Usage == DescriptionUrl {
if attr.Usage == DescriptionURL {
attr.Data = make([]byte, 1)
return binary.Read(r, binary.LittleEndian, attr.Data)
}
@ -63,7 +63,7 @@ func (attr *Attribute) EncodeBinary(w io.Writer) error {
if attr.Usage == Script {
return binary.Write(w, binary.LittleEndian, attr.Data)
}
if attr.Usage == DescriptionUrl {
if attr.Usage == DescriptionURL {
if err := util.WriteVarUint(w, uint64(len(attr.Data))); err != nil {
return err
}

View file

@ -0,0 +1,19 @@
package transaction
import (
"io"
)
// ContractTX represents a contract transaction.
// This TX has not special attributes.
type ContractTX struct{}
// DecodeBinary implements the Payload interface.
func (tx *ContractTX) DecodeBinary(r io.Reader) error {
return nil
}
// EncodeBinary implements the Payload interface.
func (tx *ContractTX) EncodeBinary(w io.Writer) error {
return nil
}

View file

@ -0,0 +1,8 @@
package transaction
var (
// https://neotracker.io/tx/2c6a45547b3898318e400e541628990a07acb00f3b9a15a8e966ae49525304da
rawClaimTX = "020004bc67ba325d6412ff4c55b10f7e9afb54bbb2228d201b37363c3d697ac7c198f70300591cd454d7318d2087c0196abfbbd1573230380672f0f0cd004dcb4857e58cbd010031bcfbed573f5318437e95edd603922a4455ff3326a979fdd1c149a84c4cb0290000b51eb6159c58cac4fe23d90e292ad2bcb7002b0da2c474e81e1889c0649d2c490000000001e72d286979ee6cb1b7e65dfddfb2e384100b8d148e7758de42e4168b71792c603b555f00000000005d9de59d99c0d1f6ed1496444473f4a0b538302f014140456349cec43053009accdb7781b0799c6b591c812768804ab0a0b56b5eae7a97694227fcd33e70899c075848b2cee8fae733faac6865b484d3f7df8949e2aadb232103945fae1ed3c31d778f149192b76734fcc951b400ba3598faa81ff92ebe477eacac"
// https://neotracker.io/tx/fe4b3af60677204c57e573a57bdc97bc5059b05ad85b1474f84431f88d910f64
rawInvocationTX = "d101590400b33f7114839c33710da24cf8e7d536b8d244f3991cf565c8146063795d3b9b3cd55aef026eae992b91063db0db53c1087472616e7366657267c5cc1cb5392019e2cc4e6d6b5ea54c8d4b6d11acf166cb072961424c54f6000000000000000001206063795d3b9b3cd55aef026eae992b91063db0db0000014140c6a131c55ca38995402dff8e92ac55d89cbed4b98dfebbcb01acbc01bd78fa2ce2061be921b8999a9ab79c2958875bccfafe7ce1bbbaf1f56580815ea3a4feed232102d41ddce2c97be4c9aa571b8a32cbc305aa29afffbcae71b0ef568db0e93929aaac"
)

View file

@ -1,6 +1,8 @@
package transaction
import (
"bytes"
"crypto/sha256"
"encoding/binary"
"io"
@ -17,7 +19,7 @@ type Transaction struct {
// Data specific to the type of the transaction.
// This is always a pointer to a <Type>Transaction.
Data interface{}
Data TXer
// Transaction attributes.
Attributes []*Attribute
@ -32,6 +34,14 @@ type Transaction struct {
// Scripts exist out of the verification script
// and invocation script.
Scripts []*Witness
// hash of the transaction
hash util.Uint256
}
// Hash return the hash of the transaction.
func (t *Transaction) Hash() util.Uint256 {
return t.hash
}
// AddOutput adds the given output to the transaction outputs.
@ -92,6 +102,14 @@ func (t *Transaction) DecodeBinary(r io.Reader) error {
}
}
// Create the hash of the transaction at decode, so we dont need
// to do it anymore.
hash, err := t.createHash()
if err != nil {
return err
}
t.hash = hash
return nil
}
@ -106,19 +124,41 @@ func (t *Transaction) decodeData(r io.Reader) error {
case ClaimType:
t.Data = &ClaimTX{}
return t.Data.(*ClaimTX).DecodeBinary(r)
case ContractType:
t.Data = &ContractTX{}
return t.Data.(*ContractTX).DecodeBinary(r)
}
return nil
}
// EncodeBinary implements the payload interface.
func (t *Transaction) EncodeBinary(w io.Writer) error {
if err := t.EncodeBinaryUnsigned(w); err != nil {
return err
}
if err := util.WriteVarUint(w, uint64(len(t.Scripts))); err != nil {
return err
}
for _, s := range t.Scripts {
if err := s.EncodeBinary(w); err != nil {
return err
}
}
return nil
}
// EncodeBinaryUnsigned will only encode the fields that are not used for
// signing the transaction, which are all fields except the scripts.
func (t *Transaction) EncodeBinaryUnsigned(w io.Writer) error {
if err := binary.Write(w, binary.LittleEndian, t.Type); err != nil {
return err
}
if err := binary.Write(w, binary.LittleEndian, t.Version); err != nil {
return err
}
if err := t.encodeData(w); err != nil {
// Underlying TXer.
if err := t.Data.EncodeBinary(w); err != nil {
return err
}
@ -151,28 +191,19 @@ func (t *Transaction) EncodeBinary(w io.Writer) error {
return err
}
}
// Scripts
if err := util.WriteVarUint(w, uint64(len(t.Scripts))); err != nil {
return err
}
for _, s := range t.Scripts {
if err := s.EncodeBinary(w); err != nil {
return err
}
}
return nil
}
func (t *Transaction) encodeData(w io.Writer) error {
switch t.Type {
case InvocationType:
return t.Data.(*InvocationTX).EncodeBinary(w)
case MinerType:
return t.Data.(*MinerTX).EncodeBinary(w)
case ClaimType:
return t.Data.(*ClaimTX).EncodeBinary(w)
func (t *Transaction) createHash() (hash util.Uint256, err error) {
buf := new(bytes.Buffer)
if err = t.EncodeBinaryUnsigned(buf); err != nil {
return
}
return nil
sha := sha256.New()
sha.Write(buf.Bytes())
b := sha.Sum(nil)
sha.Reset()
sha.Write(b)
b = sha.Sum(nil)
return util.Uint256DecodeBytes(util.ArrayReverse(b))
}

View file

@ -9,11 +9,8 @@ import (
"github.com/stretchr/testify/assert"
)
// Source of this TX: https://neotracker.io/tx/2c6a45547b3898318e400e541628990a07acb00f3b9a15a8e966ae49525304da
var rawTXClaim = "020004bc67ba325d6412ff4c55b10f7e9afb54bbb2228d201b37363c3d697ac7c198f70300591cd454d7318d2087c0196abfbbd1573230380672f0f0cd004dcb4857e58cbd010031bcfbed573f5318437e95edd603922a4455ff3326a979fdd1c149a84c4cb0290000b51eb6159c58cac4fe23d90e292ad2bcb7002b0da2c474e81e1889c0649d2c490000000001e72d286979ee6cb1b7e65dfddfb2e384100b8d148e7758de42e4168b71792c603b555f00000000005d9de59d99c0d1f6ed1496444473f4a0b538302f014140456349cec43053009accdb7781b0799c6b591c812768804ab0a0b56b5eae7a97694227fcd33e70899c075848b2cee8fae733faac6865b484d3f7df8949e2aadb232103945fae1ed3c31d778f149192b76734fcc951b400ba3598faa81ff92ebe477eacac"
func TestDecodeEncodeClaimTX(t *testing.T) {
b, err := hex.DecodeString(rawTXClaim)
b, err := hex.DecodeString(rawClaimTX)
if err != nil {
t.Fatal(err)
}
@ -41,14 +38,14 @@ func TestDecodeEncodeClaimTX(t *testing.T) {
if err := tx.EncodeBinary(buf); err != nil {
t.Fatal(err)
}
assert.Equal(t, rawTXClaim, hex.EncodeToString(buf.Bytes()))
assert.Equal(t, rawClaimTX, hex.EncodeToString(buf.Bytes()))
hash := "2c6a45547b3898318e400e541628990a07acb00f3b9a15a8e966ae49525304da"
assert.Equal(t, hash, tx.hash.String())
}
// Source of this TX: https://neotracker.io/tx/fe4b3af60677204c57e573a57bdc97bc5059b05ad85b1474f84431f88d910f64
var rawTXInvocation = "d101590400b33f7114839c33710da24cf8e7d536b8d244f3991cf565c8146063795d3b9b3cd55aef026eae992b91063db0db53c1087472616e7366657267c5cc1cb5392019e2cc4e6d6b5ea54c8d4b6d11acf166cb072961424c54f6000000000000000001206063795d3b9b3cd55aef026eae992b91063db0db0000014140c6a131c55ca38995402dff8e92ac55d89cbed4b98dfebbcb01acbc01bd78fa2ce2061be921b8999a9ab79c2958875bccfafe7ce1bbbaf1f56580815ea3a4feed232102d41ddce2c97be4c9aa571b8a32cbc305aa29afffbcae71b0ef568db0e93929aaac"
func TestDecodeEncodeInvocationTX(t *testing.T) {
b, err := hex.DecodeString(rawTXInvocation)
b, err := hex.DecodeString(rawInvocationTX)
if err != nil {
t.Fatal(err)
}
@ -77,5 +74,5 @@ func TestDecodeEncodeInvocationTX(t *testing.T) {
if err := tx.EncodeBinary(buf); err != nil {
t.Fatal(err)
}
assert.Equal(t, rawTXInvocation, hex.EncodeToString(buf.Bytes()))
assert.Equal(t, rawInvocationTX, hex.EncodeToString(buf.Bytes()))
}

View file

@ -0,0 +1,10 @@
package transaction
import "io"
//TXer is interface that can act as the underlying data of
// a transaction.
type TXer interface {
DecodeBinary(io.Reader) error
EncodeBinary(io.Writer) error
}

113
pkg/network/discovery.go Normal file
View file

@ -0,0 +1,113 @@
package network
import (
"time"
)
const (
maxPoolSize = 200
)
// Discoverer is an interface that is responsible for maintaining
// a healty connection pool.
type Discoverer interface {
BackFill(...string)
PoolCount() int
RequestRemote(int)
}
// DefaultDiscovery
type DefaultDiscovery struct {
transport Transporter
dialTimeout time.Duration
addrs map[string]bool
badAddrs map[string]bool
requestCh chan int
backFill chan string
pool chan string
}
// NewDefaultDiscovery returns a new DefaultDiscovery.
func NewDefaultDiscovery(dt time.Duration, ts Transporter) *DefaultDiscovery {
d := &DefaultDiscovery{
transport: ts,
dialTimeout: dt,
addrs: make(map[string]bool),
badAddrs: make(map[string]bool),
requestCh: make(chan int),
backFill: make(chan string),
pool: make(chan string, maxPoolSize),
}
go d.run()
return d
}
// BackFill implements the Discoverer interface and will backfill the
// the pool with the given addresses.
func (d *DefaultDiscovery) BackFill(addrs ...string) {
if len(d.pool) == maxPoolSize {
return
}
for _, addr := range addrs {
d.backFill <- addr
}
}
// PoolCount returns the number of available node addresses.
func (d *DefaultDiscovery) PoolCount() int {
return len(d.pool)
}
// Request will try to establish a connection with n nodes.
func (d *DefaultDiscovery) RequestRemote(n int) {
d.requestCh <- n
}
func (d *DefaultDiscovery) work(addrCh, badAddrCh chan string) {
for {
addr := <-addrCh
if err := d.transport.Dial(addr, d.dialTimeout); err != nil {
badAddrCh <- addr
}
}
}
func (d *DefaultDiscovery) next() string {
return <-d.pool
}
func (d *DefaultDiscovery) run() {
var (
maxWorkers = 5
badAddrCh = make(chan string)
workCh = make(chan string)
)
for i := 0; i < maxWorkers; i++ {
go d.work(workCh, badAddrCh)
}
for {
select {
case addr := <-d.backFill:
if _, ok := d.badAddrs[addr]; ok {
break
}
if _, ok := d.addrs[addr]; !ok {
d.addrs[addr] = true
d.pool <- addr
}
case n := <-d.requestCh:
go func() {
for i := 0; i < n; i++ {
workCh <- d.next()
}
}()
case addr := <-badAddrCh:
d.badAddrs[addr] = true
go func() {
workCh <- d.next()
}()
}
}
}

106
pkg/network/helper_test.go Normal file
View file

@ -0,0 +1,106 @@
package network
import (
"testing"
"time"
"github.com/CityOfZion/neo-go/pkg/core"
"github.com/CityOfZion/neo-go/pkg/network/payload"
"github.com/CityOfZion/neo-go/pkg/util"
)
type testChain struct{}
func (chain testChain) AddHeaders(...*core.Header) error {
return nil
}
func (chain testChain) AddBlock(*core.Block) error {
return nil
}
func (chain testChain) BlockHeight() uint32 {
return 0
}
func (chain testChain) HeaderHeight() uint32 {
return 0
}
func (chain testChain) GetHeaderHash(int) util.Uint256 {
return util.Uint256{}
}
func (chain testChain) CurrentHeaderHash() util.Uint256 {
return util.Uint256{}
}
func (chain testChain) CurrentBlockHash() util.Uint256 {
return util.Uint256{}
}
func (chain testChain) HasBlock(util.Uint256) bool {
return false
}
func (chain testChain) HasTransaction(util.Uint256) bool {
return false
}
type testDiscovery struct{}
func (d testDiscovery) BackFill(addrs ...string) {}
func (d testDiscovery) PoolCount() int { return 0 }
func (d testDiscovery) RequestRemote(n int) {}
type localTransport struct{}
func (t localTransport) Consumer() <-chan protoTuple {
ch := make(chan protoTuple)
return ch
}
func (t localTransport) Dial(addr string, timeout time.Duration) error {
return nil
}
func (t localTransport) Accept() {}
func (t localTransport) Proto() string { return "local" }
func (t localTransport) Close() {}
var defaultMessageHandler = func(t *testing.T, msg *Message) {}
type localPeer struct {
endpoint util.Endpoint
version *payload.Version
t *testing.T
messageHandler func(t *testing.T, msg *Message)
}
func newLocalPeer(t *testing.T) *localPeer {
return &localPeer{
t: t,
endpoint: util.NewEndpoint("0.0.0.0:0"),
messageHandler: defaultMessageHandler,
}
}
func (p *localPeer) Endpoint() util.Endpoint {
return p.endpoint
}
func (p *localPeer) Disconnect(err error) {}
func (p *localPeer) Send(msg *Message) {
p.messageHandler(p.t, msg)
}
func (p *localPeer) Done() chan error {
done := make(chan error)
return done
}
func (p *localPeer) Version() *payload.Version {
return p.version
}
func newTestServer() *Server {
return &Server{
Config: Config{},
chain: testChain{},
transport: localTransport{},
discovery: testDiscovery{},
id: util.RandUint32(1000000, 9999999),
quit: make(chan struct{}),
register: make(chan Peer),
unregister: make(chan peerDrop),
peers: make(map[Peer]bool),
}
}

View file

@ -36,7 +36,7 @@ func (n NetMode) String() string {
case ModeMainNet:
return "mainnet"
default:
return ""
return "net unknown"
}
}
@ -49,15 +49,20 @@ const (
// Message is the complete message send between nodes.
type Message struct {
// NetMode of the node that sends this message.
Magic NetMode
// Command is utf8 code, of which the length is 12 bytes,
// the extra part is filled with 0.
Command [cmdSize]byte
// Length of the payload
Length uint32
// Checksum is the first 4 bytes of the value that two times SHA256
// hash of the payload
Checksum uint32
// Payload send with the message.
Payload payload.Payload
}
@ -65,7 +70,7 @@ type Message struct {
// CommandType represents the type of a message command.
type CommandType string
// valid commands used to send between nodes.
// Valid protocol commands used to send between nodes.
const (
CMDVersion CommandType = "version"
CMDVerack CommandType = "verack"
@ -144,31 +149,22 @@ func (m *Message) CommandType() CommandType {
// decode a Message from the given reader.
func (m *Message) decode(r io.Reader) error {
err := binary.Read(r, binary.LittleEndian, &m.Magic)
if err != nil {
if err := binary.Read(r, binary.LittleEndian, &m.Magic); err != nil {
return err
}
err = binary.Read(r, binary.LittleEndian, &m.Command)
if err != nil {
if err := binary.Read(r, binary.LittleEndian, &m.Command); err != nil {
return err
}
err = binary.Read(r, binary.LittleEndian, &m.Length)
if err != nil {
if err := binary.Read(r, binary.LittleEndian, &m.Length); err != nil {
return err
}
err = binary.Read(r, binary.LittleEndian, &m.Checksum)
if err != nil {
if err := binary.Read(r, binary.LittleEndian, &m.Checksum); err != nil {
return err
}
// return if their is no payload.
if m.Length == 0 {
return nil
}
return m.decodePayload(r)
}
@ -188,42 +184,41 @@ func (m *Message) decodePayload(r io.Reader) error {
return errChecksumMismatch
}
r = buf
var p payload.Payload
switch m.CommandType() {
case CMDVersion:
p = &payload.Version{}
if err := p.DecodeBinary(r); err != nil {
if err := p.DecodeBinary(buf); err != nil {
return err
}
case CMDInv:
p = &payload.Inventory{}
if err := p.DecodeBinary(r); err != nil {
if err := p.DecodeBinary(buf); err != nil {
return err
}
case CMDAddr:
p = &payload.AddressList{}
if err := p.DecodeBinary(r); err != nil {
if err := p.DecodeBinary(buf); err != nil {
return err
}
case CMDBlock:
p = &core.Block{}
if err := p.DecodeBinary(r); err != nil {
if err := p.DecodeBinary(buf); err != nil {
return err
}
case CMDGetHeaders:
p = &payload.GetBlocks{}
if err := p.DecodeBinary(r); err != nil {
if err := p.DecodeBinary(buf); err != nil {
return err
}
case CMDHeaders:
p = &payload.Headers{}
if err := p.DecodeBinary(r); err != nil {
if err := p.DecodeBinary(buf); err != nil {
return err
}
case CMDTX:
p = &transaction.Transaction{}
if err := p.DecodeBinary(r); err != nil {
if err := p.DecodeBinary(buf); err != nil {
return err
}
}
@ -247,11 +242,9 @@ func (m *Message) encode(w io.Writer) error {
if err := binary.Write(w, binary.LittleEndian, m.Checksum); err != nil {
return err
}
if m.Payload != nil {
return m.Payload.EncodeBinary(w)
}
return nil
}

View file

@ -1,61 +1 @@
package network
import (
"bytes"
"testing"
"github.com/CityOfZion/neo-go/pkg/network/payload"
"github.com/stretchr/testify/assert"
)
func TestMessageEncodeDecode(t *testing.T) {
m := NewMessage(ModeTestNet, CMDVersion, nil)
buf := &bytes.Buffer{}
if err := m.encode(buf); err != nil {
t.Error(err)
}
assert.Equal(t, len(buf.Bytes()), minMessageSize)
md := &Message{}
if err := md.decode(buf); err != nil {
t.Error(err)
}
assert.Equal(t, m, md)
}
func TestMessageEncodeDecodeWithVersion(t *testing.T) {
var (
p = payload.NewVersion(12227, 2000, "/neo:2.6.0/", 0, true)
m = NewMessage(ModeTestNet, CMDVersion, p)
)
buf := new(bytes.Buffer)
if err := m.encode(buf); err != nil {
t.Error(err)
}
mDecode := &Message{}
if err := mDecode.decode(buf); err != nil {
t.Fatal(err)
}
assert.Equal(t, m, mDecode)
}
func TestMessageInvalidChecksum(t *testing.T) {
var (
p = payload.NewVersion(1111, 3000, "/NEO:2.6.0/", 0, true)
m = NewMessage(ModeTestNet, CMDVersion, p)
)
m.Checksum = 1337
buf := new(bytes.Buffer)
if err := m.encode(buf); err != nil {
t.Error(err)
}
md := &Message{}
if err := md.decode(buf); err == nil && err != errChecksumMismatch {
t.Fatalf("decode should fail with %s", errChecksumMismatch)
}
}

View file

@ -1,213 +0,0 @@
package network
import (
"errors"
"fmt"
"os"
"time"
"github.com/CityOfZion/neo-go/pkg/core"
"github.com/CityOfZion/neo-go/pkg/network/payload"
"github.com/CityOfZion/neo-go/pkg/util"
log "github.com/go-kit/kit/log"
)
const (
protoVersion = 0
)
var protoTickInterval = 5 * time.Second
// Node represents the local node.
type Node struct {
// Config fields may not be modified while the server is running.
Config
logger log.Logger
server *Server
services uint64
bc *core.Blockchain
}
func newNode(s *Server, cfg Config) *Node {
var startHash util.Uint256
if cfg.Net == ModePrivNet {
startHash = core.GenesisHashPrivNet()
}
if cfg.Net == ModeTestNet {
startHash = core.GenesisHashTestNet()
}
if cfg.Net == ModeMainNet {
startHash = core.GenesisHashMainNet()
}
bc := core.NewBlockchain(
core.NewMemoryStore(),
startHash,
)
logger := log.NewLogfmtLogger(os.Stderr)
logger = log.With(logger, "component", "node")
n := &Node{
Config: cfg,
server: s,
bc: bc,
logger: logger,
}
return n
}
func (n *Node) version() *payload.Version {
return payload.NewVersion(n.server.id, n.ListenTCP, n.UserAgent, 1, n.Relay)
}
func (n *Node) startProtocol(p Peer) {
n.logger.Log(
"event", "start protocol",
"peer", p.Endpoint(),
"userAgent", string(p.Version().UserAgent),
)
defer func() {
n.logger.Log(
"msg", "protocol stopped",
"peer", p.Endpoint(),
)
}()
timer := time.NewTimer(protoTickInterval)
for {
<-timer.C
select {
case <-p.Done():
return
default:
// Try to sync with the peer if his block height is higher then ours.
if p.Version().StartHeight > n.bc.HeaderHeight() {
n.askMoreHeaders(p)
}
// Only ask for more peers if the server has the capacity for it.
if n.server.hasCapacity() {
msg := NewMessage(n.Net, CMDGetAddr, nil)
p.Send(msg)
}
timer.Reset(protoTickInterval)
}
}
}
// When a peer sends out his version we reply with verack after validating
// the version.
func (n *Node) handleVersionCmd(version *payload.Version, p Peer) error {
msg := NewMessage(n.Net, CMDVerack, nil)
p.Send(msg)
return nil
}
// handleInvCmd handles the forwarded inventory received from the peer.
// We will use the getdata message to get more details about the received
// inventory.
// note: if the server has Relay on false, inventory messages are not received.
func (n *Node) handleInvCmd(inv *payload.Inventory, p Peer) error {
if !inv.Type.Valid() {
return fmt.Errorf("invalid inventory type received: %s", inv.Type)
}
if len(inv.Hashes) == 0 {
return errors.New("inventory has no hashes")
}
payload := payload.NewInventory(inv.Type, inv.Hashes)
p.Send(NewMessage(n.Net, CMDGetData, payload))
return nil
}
// handleBlockCmd processes the received block received from its peer.
func (n *Node) handleBlockCmd(block *core.Block, peer Peer) error {
n.logger.Log(
"event", "block received",
"index", block.Index,
"hash", block.Hash(),
"tx", len(block.Transactions),
)
return n.bc.AddBlock(block)
}
// After a node sends out the getaddr message its receives a list of known peers
// in the network. handleAddrCmd processes that payload.
func (n *Node) handleAddrCmd(addressList *payload.AddressList, peer Peer) error {
addrs := make([]string, len(addressList.Addrs))
for i := 0; i < len(addrs); i++ {
addrs[i] = addressList.Addrs[i].Address.String()
}
n.server.connectToPeers(addrs...)
return nil
}
// The handleHeadersCmd will process the received headers from its peer.
// We call this in a routine cause we may block Peers Send() for to long.
func (n *Node) handleHeadersCmd(headers *payload.Headers, peer Peer) error {
go func(headers []*core.Header) {
if err := n.bc.AddHeaders(headers...); err != nil {
n.logger.Log("msg", "failed processing headers", "err", err)
return
}
// The peer will respond with a maximum of 2000 headers in one batch.
// We will ask one more batch here if needed. Eventually we will get synced
// due to the startProtocol routine that will ask headers every protoTick.
if n.bc.HeaderHeight() < peer.Version().StartHeight {
n.askMoreHeaders(peer)
}
}(headers.Hdrs)
return nil
}
// askMoreHeaders will send a getheaders message to the peer.
func (n *Node) askMoreHeaders(p Peer) {
start := []util.Uint256{n.bc.CurrentHeaderHash()}
payload := payload.NewGetBlocks(start, util.Uint256{})
p.Send(NewMessage(n.Net, CMDGetHeaders, payload))
}
// blockhain implements the Noder interface.
func (n *Node) blockchain() *core.Blockchain { return n.bc }
func (n *Node) handleProto(msg *Message, p Peer) error {
//n.logger.Log(
// "event", "message received",
// "from", p.Endpoint(),
// "msg", msg.CommandType(),
//)
switch msg.CommandType() {
case CMDVersion:
version := msg.Payload.(*payload.Version)
return n.handleVersionCmd(version, p)
case CMDAddr:
addressList := msg.Payload.(*payload.AddressList)
return n.handleAddrCmd(addressList, p)
case CMDInv:
inventory := msg.Payload.(*payload.Inventory)
return n.handleInvCmd(inventory, p)
case CMDBlock:
block := msg.Payload.(*core.Block)
return n.handleBlockCmd(block, p)
case CMDHeaders:
headers := msg.Payload.(*payload.Headers)
return n.handleHeadersCmd(headers, p)
case CMDTX:
// tx := msg.Payload.(*transaction.Transaction)
//n.logger.Log("tx", fmt.Sprintf("%+v", tx))
return nil
case CMDVerack:
// Only start the protocol if we got the version and verack
// received.
if p.Version() != nil {
go n.startProtocol(p)
}
return nil
case CMDUnknown:
return errors.New("received non-protocol messgae")
}
return nil
}

View file

@ -1,7 +0,0 @@
package network
import "testing"
func TestHandleVersion(t *testing.T) {
}

View file

@ -12,7 +12,7 @@ import (
type AddressAndTime struct {
Timestamp uint32
Services uint64
Address util.Endpoint
Endpoint util.Endpoint
}
// NewAddressAndTime creates a new AddressAndTime object.
@ -20,7 +20,7 @@ func NewAddressAndTime(e util.Endpoint, t time.Time) *AddressAndTime {
return &AddressAndTime{
Timestamp: uint32(t.UTC().Unix()),
Services: 1,
Address: e,
Endpoint: e,
}
}
@ -32,10 +32,10 @@ func (p *AddressAndTime) DecodeBinary(r io.Reader) error {
if err := binary.Read(r, binary.LittleEndian, &p.Services); err != nil {
return err
}
if err := binary.Read(r, binary.BigEndian, &p.Address.IP); err != nil {
if err := binary.Read(r, binary.BigEndian, &p.Endpoint.IP); err != nil {
return err
}
return binary.Read(r, binary.BigEndian, &p.Address.Port)
return binary.Read(r, binary.BigEndian, &p.Endpoint.Port)
}
// EncodeBinary implements the Payload interface.
@ -46,10 +46,10 @@ func (p *AddressAndTime) EncodeBinary(w io.Writer) error {
if err := binary.Write(w, binary.LittleEndian, p.Services); err != nil {
return err
}
if err := binary.Write(w, binary.BigEndian, p.Address.IP); err != nil {
if err := binary.Write(w, binary.BigEndian, p.Endpoint.IP); err != nil {
return err
}
return binary.Write(w, binary.BigEndian, p.Address.Port)
return binary.Write(w, binary.BigEndian, p.Endpoint.Port)
}
// AddressList is a list with AddrAndTime.

View file

@ -31,14 +31,7 @@ func (p *GetBlocks) DecodeBinary(r io.Reader) error {
if err := binary.Read(r, binary.LittleEndian, &p.HashStart); err != nil {
return err
}
// If the reader returns EOF we know the hashStop is not encoded.
err := binary.Read(r, binary.LittleEndian, &p.HashStop)
if err == io.EOF {
return nil
}
return err
return binary.Read(r, binary.LittleEndian, &p.HashStop)
}
// EncodeBinary implements the payload interface.
@ -49,14 +42,7 @@ func (p *GetBlocks) EncodeBinary(w io.Writer) error {
if err := binary.Write(w, binary.LittleEndian, p.HashStart); err != nil {
return err
}
// Only write hashStop if its not filled with zero bytes.
var emtpy util.Uint256
if p.HashStop != emtpy {
return binary.Write(w, binary.LittleEndian, p.HashStop)
}
return nil
return binary.Write(w, binary.LittleEndian, p.HashStop)
}
// Size implements the payload interface.

View file

@ -3,10 +3,10 @@ package payload
import (
"bytes"
"crypto/sha256"
"reflect"
"testing"
"github.com/CityOfZion/neo-go/pkg/util"
"github.com/stretchr/testify/assert"
)
func TestGetBlockEncodeDecode(t *testing.T) {
@ -28,9 +28,7 @@ func TestGetBlockEncodeDecode(t *testing.T) {
t.Fatal(err)
}
if !reflect.DeepEqual(p, pDecode) {
t.Fatalf("expected to have equal block payload %v and %v", p, pDecode)
}
assert.Equal(t, p, pDecode)
}
func TestGetBlockEncodeDecodeWithHashStop(t *testing.T) {
@ -54,7 +52,5 @@ func TestGetBlockEncodeDecodeWithHashStop(t *testing.T) {
t.Fatal(err)
}
if !reflect.DeepEqual(p, pDecode) {
t.Fatalf("expected to have equal block payload %v and %v", p, pDecode)
}
assert.Equal(t, p, pDecode)
}

View file

@ -31,7 +31,9 @@ func (p *Headers) DecodeBinary(r io.Reader) error {
// EncodeBinary implements the Payload interface.
func (p *Headers) EncodeBinary(w io.Writer) error {
util.WriteVarUint(w, uint64(len(p.Hdrs)))
if err := util.WriteVarUint(w, uint64(len(p.Hdrs))); err != nil {
return err
}
for _, header := range p.Hdrs {
if err := header.EncodeBinary(w); err != nil {
return err

View file

@ -34,8 +34,8 @@ func (i InventoryType) Valid() bool {
// List of valid InventoryTypes.
const (
BlockType InventoryType = 0x01 // 1
TXType InventoryType = 0x02 // 2
TXType InventoryType = 0x01 // 1
BlockType InventoryType = 0x02 // 2
ConsensusType InventoryType = 0xe0 // 224
)
@ -43,7 +43,8 @@ const (
type Inventory struct {
// Type if the object hash.
Type InventoryType
// The hash of the object (uint256).
// A list of hashes.
Hashes []util.Uint256
}
@ -57,9 +58,11 @@ func NewInventory(typ InventoryType, hashes []util.Uint256) *Inventory {
// DecodeBinary implements the Payload interface.
func (p *Inventory) DecodeBinary(r io.Reader) error {
err := binary.Read(r, binary.LittleEndian, &p.Type)
listLen := util.ReadVarUint(r)
if err := binary.Read(r, binary.LittleEndian, &p.Type); err != nil {
return err
}
listLen := util.ReadVarUint(r)
p.Hashes = make([]util.Uint256, listLen)
for i := 0; i < int(listLen); i++ {
if err := binary.Read(r, binary.LittleEndian, &p.Hashes[i]); err != nil {
@ -67,25 +70,24 @@ func (p *Inventory) DecodeBinary(r io.Reader) error {
}
}
return err
return nil
}
// EncodeBinary implements the Payload interface.
func (p *Inventory) EncodeBinary(w io.Writer) error {
listLen := len(p.Hashes)
err := binary.Write(w, binary.LittleEndian, p.Type)
err = util.WriteVarUint(w, uint64(listLen))
if err := binary.Write(w, binary.LittleEndian, p.Type); err != nil {
return err
}
listLen := len(p.Hashes)
if err := util.WriteVarUint(w, uint64(listLen)); err != nil {
return err
}
for i := 0; i < len(p.Hashes); i++ {
if err := binary.Write(w, binary.LittleEndian, p.Hashes[i]); err != nil {
return err
}
}
return err
}
// Size implements the Payloader interface.
func (p *Inventory) Size() uint32 {
return 1 + 1 + 32 // ?
return nil
}

View file

@ -5,13 +5,10 @@ import (
"github.com/CityOfZion/neo-go/pkg/util"
)
// A Peer is the local representation of a remote peer.
// It's an interface that may be backed by any concrete
// transport.
type Peer interface {
Version() *payload.Version
Endpoint() util.Endpoint
Send(*Message)
Done() chan struct{}
Disconnect(err error)
Disconnect(error)
Send(msg *Message)
Done() chan error
Version() *payload.Version
}

View file

@ -1,22 +0,0 @@
package network
import (
"github.com/CityOfZion/neo-go/pkg/core"
"github.com/CityOfZion/neo-go/pkg/network/payload"
)
// A ProtoHandler is an interface that abstract the implementation
// of the NEO protocol.
type ProtoHandler interface {
version() *payload.Version
handleProto(*Message, Peer) error
}
type protoHandleFunc func(*Message, Peer) error
// Noder is anything that implements the NEO protocol
// and can return the Blockchain object.
type Noder interface {
ProtoHandler
blockchain() *core.Blockchain
}

View file

@ -1,27 +1,35 @@
package network
import (
"errors"
"fmt"
"net"
"os"
"text/tabwriter"
"sync"
"time"
"github.com/CityOfZion/neo-go/pkg/core"
"github.com/CityOfZion/neo-go/pkg/network/payload"
"github.com/CityOfZion/neo-go/pkg/util"
log "github.com/go-kit/kit/log"
log "github.com/sirupsen/logrus"
)
const (
// node version
version = "2.6.0"
// official ports according to the protocol.
portMainNet = 10333
portTestNet = 20333
maxPeers = 50
maxPeers = 50
minPeers = 5
maxBlockBatch = 200
minPoolCount = 30
)
var dialTimeout = 4 * time.Second
var (
protoTickInterval = 10 * time.Second
dialTimeout = 3 * time.Second
errPortMismatch = errors.New("port mismatch")
errIdenticalID = errors.New("identical node id")
errInvalidHandshake = errors.New("invalid handshake")
errInvalidNetwork = errors.New("invalid network")
errServerShutdown = errors.New("server shutdown")
errInvalidInvType = errors.New("invalid inventory type")
)
// Config holds the server configuration.
type Config struct {
@ -35,10 +43,7 @@ type Config struct {
// The listen address of the TCP server.
ListenTCP uint16
// The listen address of the RPC server.
ListenRPC uint16
// The network mode this server will operate on.
// The network mode the server will operate on.
// ModePrivNet docker private network.
// ModeTestNet NEO test network.
// ModeMainNet NEO main network.
@ -52,252 +57,315 @@ type Config struct {
// Maximum duration a single dial may take.
DialTimeout time.Duration
// The duration between protocol ticks with each connected peer.
// When this is 0, the default interval of 5 seconds will be used.
ProtoTickInterval time.Duration
// Level of the internal logger.
LogLevel log.Level
}
// Server manages all incoming peer connections.
type Server struct {
// Config fields may not be modified while the server is running.
Config
type (
// Server represents the local Node in the network. Its transport could
// be of any kind.
Server struct {
// Config holds the Server configuration.
Config
// Proto is just about anything that can handle the NEO protocol.
// In production enviroments the ProtoHandler is mostly the local node.
proto ProtoHandler
// id also known as the nonce of te server.
id uint32
// Unique id of this server.
id uint32
transport Transporter
discovery Discoverer
chain core.Blockchainer
logger log.Logger
listener net.Listener
lock sync.RWMutex
peers map[Peer]bool
register chan Peer
unregister chan peerDrop
register chan Peer
unregister chan peerDrop
quit chan struct{}
badAddrOp chan func(map[string]bool)
badAddrOpDone chan struct{}
peerOp chan func(map[Peer]bool)
peerOpDone chan struct{}
quit chan struct{}
}
type peerDrop struct {
p Peer
err error
}
// NewServer returns a new Server object created from the
// given config.
func NewServer(cfg Config) *Server {
if cfg.MaxPeers == 0 {
cfg.MaxPeers = maxPeers
proto <-chan protoTuple
}
if cfg.Net == 0 {
cfg.Net = ModeTestNet
protoTuple struct {
msg *Message
peer Peer
}
peerDrop struct {
peer Peer
reason error
}
)
// NewServer returns a new Server, initialized with the given configuration.
func NewServer(cfg Config, chain *core.Blockchain) *Server {
if cfg.ProtoTickInterval == 0 {
cfg.ProtoTickInterval = protoTickInterval
}
if cfg.DialTimeout == 0 {
cfg.DialTimeout = dialTimeout
}
logger := log.NewLogfmtLogger(os.Stderr)
logger = log.With(logger, "component", "server")
if cfg.MaxPeers == 0 {
cfg.MaxPeers = maxPeers
}
log.SetLevel(log.DebugLevel)
s := &Server{
Config: cfg,
logger: logger,
id: util.RandUint32(1000000, 9999999),
quit: make(chan struct{}, 1),
register: make(chan Peer),
unregister: make(chan peerDrop),
badAddrOp: make(chan func(map[string]bool)),
badAddrOpDone: make(chan struct{}),
peerOp: make(chan func(map[Peer]bool)),
peerOpDone: make(chan struct{}),
Config: cfg,
chain: chain,
id: util.RandUint32(1000000, 9999999),
quit: make(chan struct{}),
register: make(chan Peer),
unregister: make(chan peerDrop),
peers: make(map[Peer]bool),
}
s.proto = newNode(s, cfg)
s.transport = NewTCPTransport(s, fmt.Sprintf(":%d", cfg.ListenTCP))
s.proto = s.transport.Consumer()
s.discovery = NewDefaultDiscovery(
s.DialTimeout,
s.transport,
)
return s
}
func (s *Server) createListener() error {
ln, err := net.Listen("tcp", fmt.Sprintf(":%d", s.ListenTCP))
if err != nil {
return err
}
s.listener = ln
return nil
}
func (s *Server) listenTCP() {
for {
conn, err := s.listener.Accept()
if err != nil {
s.logger.Log("msg", "conn read error", "err", err)
break
}
go s.setupPeerConn(conn)
}
s.Quit()
}
// setupPeerConn runs in its own routine for each connected Peer.
// and waits till the Peer.Run() returns.
func (s *Server) setupPeerConn(conn net.Conn) {
if !s.hasCapacity() {
s.logger.Log("msg", "server reached maximum capacity")
return
}
p := NewTCPPeer(conn, s.proto.handleProto)
s.register <- p
err := p.run()
s.unregister <- peerDrop{p, err}
}
func (s *Server) connectToPeers(addrs ...string) {
for _, addr := range addrs {
if s.hasCapacity() && s.canConnectWith(addr) {
go func(addr string) {
conn, err := net.DialTimeout("tcp", addr, s.DialTimeout)
if err != nil {
s.badAddrOp <- func(badAddrs map[string]bool) {
badAddrs[addr] = true
}
<-s.badAddrOpDone
return
}
go s.setupPeerConn(conn)
}(addr)
}
}
}
func (s *Server) canConnectWith(addr string) bool {
canConnect := true
s.peerOp <- func(peers map[Peer]bool) {
for peer := range peers {
if peer.Endpoint().String() == addr {
canConnect = false
break
}
}
}
<-s.peerOpDone
if !canConnect {
return false
}
s.badAddrOp <- func(badAddrs map[string]bool) {
_, ok := badAddrs[addr]
canConnect = !ok
}
<-s.badAddrOpDone
return canConnect
}
func (s *Server) hasCapacity() bool {
return s.PeerCount() != s.MaxPeers
}
func (s *Server) sendVersion(p Peer) {
p.Send(NewMessage(s.Net, CMDVersion, s.proto.version()))
// Start will start the server and its underlying transport.
func (s *Server) Start() {
go s.transport.Accept()
s.discovery.BackFill(s.Seeds...)
s.run()
}
func (s *Server) run() {
var (
peers = make(map[Peer]bool)
badAddrs = make(map[string]bool)
)
// Ask discovery to connect with remote nodes to fill up
// the server minimum peer slots.
s.discovery.RequestRemote(minPeers - s.PeerCount())
for {
select {
case op := <-s.badAddrOp:
op(badAddrs)
s.badAddrOpDone <- struct{}{}
case op := <-s.peerOp:
op(peers)
s.peerOpDone <- struct{}{}
case p := <-s.register:
peers[p] = true
// When a new peer connection is established, we send
// out our version immediately.
s.sendVersion(p)
s.logger.Log("event", "peer connected", "endpoint", p.Endpoint())
case drop := <-s.unregister:
delete(peers, drop.p)
s.logger.Log(
"event", "peer disconnected",
"endpoint", drop.p.Endpoint(),
"reason", drop.err,
"peerCount", len(peers),
)
if len(peers) == 0 {
s.logger.Log("fatal", "no more available peers")
return
case proto := <-s.proto:
if err := s.processProto(proto); err != nil {
proto.peer.Disconnect(err)
// verack and version implies that the protocol is
// not started and the only way to disconnect them
// from the server is to manually call unregister.
switch proto.msg.CommandType() {
case CMDVerack, CMDVersion:
go func() {
s.unregister <- peerDrop{proto.peer, err}
}()
}
}
case <-s.quit:
s.transport.Close()
for p, _ := range s.peers {
p.Disconnect(errServerShutdown)
}
return
case p := <-s.register:
// When a new peer is connected we send out our version immediately.
s.sendVersion(p)
s.peers[p] = true
log.WithFields(log.Fields{
"endpoint": p.Endpoint(),
}).Info("new peer connected")
case drop := <-s.unregister:
s.discovery.RequestRemote(1)
delete(s.peers, drop.peer)
log.WithFields(log.Fields{
"endpoint": drop.peer.Endpoint(),
"reason": drop.reason,
"peerCount": s.PeerCount(),
}).Warn("peer disconnected")
}
}
}
// PeerCount returns the number of current connected peers.
func (s *Server) PeerCount() (n int) {
s.peerOp <- func(peers map[Peer]bool) {
n = len(peers)
}
<-s.peerOpDone
return
func (s *Server) PeerCount() int {
s.lock.RLock()
defer s.lock.RUnlock()
return len(s.peers)
}
func (s *Server) Start() error {
fmt.Println(logo())
fmt.Println("")
s.printConfiguration()
// startProtocol starts a long running background loop that interacts
// every ProtoTickInterval with the peer.
func (s *Server) startProtocol(p Peer) {
log.WithFields(log.Fields{
"endpoint": p.Endpoint(),
"userAgent": string(p.Version().UserAgent),
"startHeight": p.Version().StartHeight,
"id": p.Version().Nonce,
}).Info("started protocol")
if err := s.createListener(); err != nil {
return err
s.requestHeaders(p)
s.requestPeerInfo(p)
timer := time.NewTimer(s.ProtoTickInterval)
for {
select {
case err := <-p.Done():
s.unregister <- peerDrop{p, err}
return
case <-timer.C:
// Try to sync in headers and block with the peer if his block height is higher then ours.
if p.Version().StartHeight > s.chain.BlockHeight() {
s.requestBlocks(p)
}
// If the discovery does not have a healthy address pool
// we will ask for a new batch of addresses.
if s.discovery.PoolCount() < minPoolCount {
s.requestPeerInfo(p)
}
timer.Reset(s.ProtoTickInterval)
}
}
}
// When a peer connects to the server, we will send our version immediately.
func (s *Server) sendVersion(p Peer) {
payload := payload.NewVersion(
s.id,
s.ListenTCP,
s.UserAgent,
s.chain.BlockHeight(),
s.Relay,
)
p.Send(NewMessage(s.Net, CMDVersion, payload))
}
// When a peer sends out his version we reply with verack after validating
// the version.
func (s *Server) handleVersionCmd(p Peer, version *payload.Version) error {
if p.Endpoint().Port != version.Port {
return errPortMismatch
}
if s.id == version.Nonce {
return errIdenticalID
}
p.Send(NewMessage(s.Net, CMDVerack, nil))
return nil
}
// handleHeadersCmd will process the headers it received from its peer.
// if the headerHeight of the blockchain still smaller then the peer
// the server will request more headers.
// This method could best be called in a separate routine.
func (s *Server) handleHeadersCmd(p Peer, headers *payload.Headers) {
if err := s.chain.AddHeaders(headers.Hdrs...); err != nil {
log.Warnf("failed processing headers: %s", err)
return
}
// The peer will respond with a maximum of 2000 headers in one batch.
// We will ask one more batch here if needed. Eventually we will get synced
// due to the startProtocol routine that will ask headers every protoTick.
if s.chain.HeaderHeight() < p.Version().StartHeight {
s.requestHeaders(p)
}
}
// handleBlockCmd processes the received block received from its peer.
func (s *Server) handleBlockCmd(p Peer, block *core.Block) error {
return s.chain.AddBlock(block)
}
// handleInvCmd will process the received inventory.
func (s *Server) handleInvCmd(p Peer, inv *payload.Inventory) error {
if !inv.Type.Valid() || len(inv.Hashes) == 0 {
return errInvalidInvType
}
payload := payload.NewInventory(inv.Type, inv.Hashes)
p.Send(NewMessage(s.Net, CMDGetData, payload))
return nil
}
func (s *Server) handleGetHeadersCmd(p Peer, getHeaders *payload.GetBlocks) error {
log.Info(getHeaders)
return nil
}
// requestHeaders will send a getheaders message to the peer.
// The peer will respond with headers op to a count of 2000.
func (s *Server) requestHeaders(p Peer) {
start := []util.Uint256{s.chain.CurrentHeaderHash()}
payload := payload.NewGetBlocks(start, util.Uint256{})
p.Send(NewMessage(s.Net, CMDGetHeaders, payload))
}
// requestPeerInfo will send a getaddr message to the peer
// which will respond with his known addresses in the network.
func (s *Server) requestPeerInfo(p Peer) {
p.Send(NewMessage(s.Net, CMDGetAddr, nil))
}
// requestBlocks will send a getdata message to the peer
// to sync up in blocks. A maximum of maxBlockBatch will
// send at once.
func (s *Server) requestBlocks(p Peer) {
var (
hashStart = s.chain.BlockHeight() + 1
headerHeight = s.chain.HeaderHeight()
hashes = []util.Uint256{}
)
for hashStart < headerHeight && len(hashes) < maxBlockBatch {
hash := s.chain.GetHeaderHash(int(hashStart))
hashes = append(hashes, hash)
hashStart++
}
if len(hashes) > 0 {
payload := payload.NewInventory(payload.BlockType, hashes)
p.Send(NewMessage(s.Net, CMDGetData, payload))
} else if s.chain.HeaderHeight() < p.Version().StartHeight {
s.requestHeaders(p)
}
}
// process the received protocol message.
func (s *Server) processProto(proto protoTuple) error {
var (
peer = proto.peer
msg = proto.msg
)
// Make sure both server and peer are operating on
// the same network.
if msg.Magic != s.Net {
return errInvalidNetwork
}
go s.run()
go s.listenTCP()
go s.connectToPeers(s.Seeds...)
select {}
}
func (s *Server) Quit() {
s.quit <- struct{}{}
}
func (s *Server) printState() {
w := tabwriter.NewWriter(os.Stdout, 0, 0, 4, ' ', 0)
fmt.Fprintf(w, "connected peers:\t%d/%d\n", s.PeerCount(), s.MaxPeers)
w.Flush()
}
func (s *Server) printConfiguration() {
w := tabwriter.NewWriter(os.Stdout, 0, 0, 4, ' ', 0)
fmt.Fprintf(w, "user agent:\t%s\n", s.UserAgent)
fmt.Fprintf(w, "id:\t%d\n", s.id)
fmt.Fprintf(w, "network:\t%s\n", s.Net)
fmt.Fprintf(w, "listen TCP:\t%d\n", s.ListenTCP)
fmt.Fprintf(w, "listen RPC:\t%d\n", s.ListenRPC)
fmt.Fprintf(w, "relay:\t%v\n", s.Relay)
fmt.Fprintf(w, "max peers:\t%d\n", s.MaxPeers)
chainer := s.proto.(Noder)
fmt.Fprintf(w, "current height:\t%d\n", chainer.blockchain().HeaderHeight())
fmt.Fprintln(w, "")
w.Flush()
}
func logo() string {
return `
_ ____________ __________
/ | / / ____/ __ \ / ____/ __ \
/ |/ / __/ / / / /_____/ / __/ / / /
/ /| / /___/ /_/ /_____/ /_/ / /_/ /
/_/ |_/_____/\____/ \____/\____/
`
switch msg.CommandType() {
case CMDVersion:
version := msg.Payload.(*payload.Version)
return s.handleVersionCmd(peer, version)
case CMDHeaders:
headers := msg.Payload.(*payload.Headers)
go s.handleHeadersCmd(peer, headers)
case CMDInv:
inventory := msg.Payload.(*payload.Inventory)
return s.handleInvCmd(peer, inventory)
case CMDBlock:
block := msg.Payload.(*core.Block)
return s.handleBlockCmd(peer, block)
case CMDGetHeaders:
getHeaders := msg.Payload.(*payload.GetBlocks)
s.handleGetHeadersCmd(peer, getHeaders)
case CMDVerack:
// Make sure this peer has sended his version before we start the
// protocol.
if peer.Version() == nil {
return errInvalidHandshake
}
go s.startProtocol(peer)
case CMDAddr:
addressList := msg.Payload.(*payload.AddressList)
for _, addr := range addressList.Addrs {
s.discovery.BackFill(addr.Endpoint.String())
}
}
return nil
}

View file

@ -1,92 +1,114 @@
package network
import (
"os"
"testing"
"github.com/CityOfZion/neo-go/pkg/network/payload"
"github.com/CityOfZion/neo-go/pkg/util"
log "github.com/go-kit/kit/log"
"github.com/stretchr/testify/assert"
)
func TestRegisterPeer(t *testing.T) {
s := newTestServer()
func TestSendVersion(t *testing.T) {
var (
s = newTestServer()
p = newLocalPeer(t)
)
s.ListenTCP = 3000
s.UserAgent = "/test/"
p.messageHandler = func(t *testing.T, msg *Message) {
assert.Equal(t, CMDVersion, msg.CommandType())
assert.IsType(t, msg.Payload, &payload.Version{})
version := msg.Payload.(*payload.Version)
assert.NotZero(t, version.Nonce)
assert.Equal(t, uint16(3000), version.Port)
assert.Equal(t, uint64(1), version.Services)
assert.Equal(t, uint32(0), version.Version)
assert.Equal(t, []byte("/test/"), version.UserAgent)
assert.Equal(t, uint32(0), version.StartHeight)
}
s.sendVersion(p)
}
func TestRequestPeerInfo(t *testing.T) {
var (
s = newTestServer()
p = newLocalPeer(t)
)
p.messageHandler = func(t *testing.T, msg *Message) {
assert.Equal(t, CMDGetAddr, msg.CommandType())
assert.Nil(t, msg.Payload)
}
s.requestPeerInfo(p)
}
// Server should reply with a verack after receiving a valid version.
func TestVerackAfterHandleVersionCmd(t *testing.T) {
var (
s = newTestServer()
p = newLocalPeer(t)
)
p.endpoint = util.NewEndpoint("0.0.0.0:3000")
// Should have a verack
p.messageHandler = func(t *testing.T, msg *Message) {
assert.Equal(t, CMDVerack, msg.CommandType())
}
version := payload.NewVersion(1337, 3000, "/NEO-GO/", 0, true)
if err := s.handleVersionCmd(p, version); err != nil {
t.Fatal(err)
}
}
// Server should not reply with a verack after receiving a
// invalid version and disconnects the peer.
func TestServerNotSendsVerack(t *testing.T) {
var (
s = newTestServer()
p = newLocalPeer(t)
)
s.id = 1
go s.run()
assert.NotZero(t, s.id)
assert.Zero(t, s.PeerCount())
p.endpoint = util.NewEndpoint("0.0.0.0:3000")
s.register <- p
lenPeers := 10
for i := 0; i < lenPeers; i++ {
s.register <- newTestPeer()
// Port should mismatch
version := payload.NewVersion(1337, 2000, "/NEO-GO/", 0, true)
err := s.handleVersionCmd(p, version)
assert.NotNil(t, err)
assert.Equal(t, errPortMismatch, err)
// identical id's
version = payload.NewVersion(1, 3000, "/NEO-GO/", 0, true)
err = s.handleVersionCmd(p, version)
assert.NotNil(t, err)
assert.Equal(t, errIdenticalID, err)
}
func TestRequestPeers(t *testing.T) {
var (
s = newTestServer()
p = newLocalPeer(t)
)
p.messageHandler = func(t *testing.T, msg *Message) {
assert.Nil(t, msg.Payload)
assert.Equal(t, CMDGetAddr, msg.CommandType())
}
assert.Equal(t, lenPeers, s.PeerCount())
s.requestPeerInfo(p)
}
func TestUnregisterPeer(t *testing.T) {
s := newTestServer()
go s.run()
peer := newTestPeer()
s.register <- peer
s.register <- newTestPeer()
s.register <- newTestPeer()
assert.Equal(t, 3, s.PeerCount())
s.unregister <- peerDrop{peer, nil}
assert.Equal(t, 2, s.PeerCount())
}
type testNode struct{}
func (t testNode) version() *payload.Version {
return &payload.Version{}
}
func (t testNode) handleProto(msg *Message, p Peer) error {
return nil
}
func newTestServer() *Server {
return &Server{
logger: log.NewLogfmtLogger(os.Stderr),
id: util.RandUint32(1000000, 9999999),
quit: make(chan struct{}, 1),
register: make(chan Peer),
unregister: make(chan peerDrop),
badAddrOp: make(chan func(map[string]bool)),
badAddrOpDone: make(chan struct{}),
peerOp: make(chan func(map[Peer]bool)),
peerOpDone: make(chan struct{}),
proto: testNode{},
func TestRequestHeaders(t *testing.T) {
var (
s = newTestServer()
p = newLocalPeer(t)
)
p.messageHandler = func(t *testing.T, msg *Message) {
assert.IsType(t, &payload.GetBlocks{}, msg.Payload)
assert.Equal(t, CMDGetHeaders, msg.CommandType())
}
}
type testPeer struct {
done chan struct{}
}
func newTestPeer() testPeer {
return testPeer{
done: make(chan struct{}),
}
}
func (p testPeer) Version() *payload.Version {
return &payload.Version{}
}
func (p testPeer) Endpoint() util.Endpoint {
return util.Endpoint{}
}
func (p testPeer) Send(msg *Message) {}
func (p testPeer) Done() chan struct{} {
return p.done
}
func (p testPeer) Disconnect(err error) {
s.requestHeaders(p)
}

View file

@ -1,70 +1,49 @@
package network
import (
"bytes"
"net"
"os"
"sync"
"time"
"github.com/CityOfZion/neo-go/pkg/network/payload"
"github.com/CityOfZion/neo-go/pkg/util"
log "github.com/go-kit/kit/log"
)
// TCPPeer represents a connected remote node in the
// network over TCP.
type TCPPeer struct {
// The endpoint of the peer.
// underlying TCP connection.
conn net.Conn
endpoint util.Endpoint
// underlying connection.
conn net.Conn
// The version the peer declared when connecting.
// The version of the peer.
version *payload.Version
// connectedAt is the timestamp the peer connected to
// the network.
connectedAt time.Time
done chan error
closed chan struct{}
disc chan error
// handleProto is the handler that will handle the
// incoming message along with its peer.
handleProto protoHandleFunc
// Done is used to broadcast that this peer has stopped running
// and should be removed as reference.
done chan struct{}
// Every send to this channel will terminate the Peer.
discErr chan error
closed chan struct{}
wg sync.WaitGroup
logger log.Logger
wg sync.WaitGroup
}
// NewTCPPeer creates a new peer from a TCP connection.
func NewTCPPeer(conn net.Conn, fun protoHandleFunc) *TCPPeer {
e := util.NewEndpoint(conn.RemoteAddr().String())
logger := log.NewLogfmtLogger(os.Stderr)
logger = log.With(logger, "component", "peer", "endpoint", e)
func NewTCPPeer(conn net.Conn, proto chan protoTuple) *TCPPeer {
return &TCPPeer{
endpoint: e,
conn: conn,
done: make(chan struct{}),
logger: logger,
connectedAt: time.Now().UTC(),
handleProto: fun,
discErr: make(chan error),
closed: make(chan struct{}),
conn: conn,
done: make(chan error),
closed: make(chan struct{}),
disc: make(chan error),
endpoint: util.NewEndpoint(conn.RemoteAddr().String()),
}
}
// Version implements the Peer interface.
func (p *TCPPeer) Version() *payload.Version {
return p.version
// Send implements the Peer interface. This will encode the message
// to the underlying connection.
func (p *TCPPeer) Send(msg *Message) {
if err := msg.encode(p.conn); err != nil {
select {
case p.disc <- err:
case <-p.closed:
}
}
}
// Endpoint implements the Peer interface.
@ -72,57 +51,19 @@ func (p *TCPPeer) Endpoint() util.Endpoint {
return p.endpoint
}
// Send implements the Peer interface.
func (p *TCPPeer) Send(msg *Message) {
buf := new(bytes.Buffer)
if err := msg.encode(buf); err != nil {
p.discErr <- err
return
}
if _, err := p.conn.Write(buf.Bytes()); err != nil {
p.discErr <- err
return
}
}
// Done implemnets the Peer interface. It use is to
// notify the Node that this peer is no longer available
// for sending messages to.
func (p *TCPPeer) Done() chan struct{} {
// Done implements the Peer interface and notifies
// all other resources operating on it that this peer
// is no longer running.
func (p *TCPPeer) Done() chan error {
return p.done
}
// Disconnect terminates the peer connection.
func (p *TCPPeer) Disconnect(err error) {
select {
case p.discErr <- err:
case <-p.closed:
}
// Version implements the Peer interface.
func (p *TCPPeer) Version() *payload.Version {
return p.version
}
func (p *TCPPeer) run() (err error) {
p.wg.Add(1)
go p.readLoop()
run:
for {
select {
case err = <-p.discErr:
break run
}
}
p.conn.Close()
close(p.closed)
// Close done instead of sending empty struct.
// It could happen that startProtocol in Node never happens
// on connection errors for example.
close(p.done)
p.wg.Wait()
return err
}
func (p *TCPPeer) readLoop() {
func (p *TCPPeer) readLoop(proto chan protoTuple, readErr chan error) {
defer p.wg.Done()
for {
select {
@ -131,23 +72,57 @@ func (p *TCPPeer) readLoop() {
default:
msg := &Message{}
if err := msg.decode(p.conn); err != nil {
p.discErr <- err
readErr <- err
return
}
p.handleMessage(msg)
p.handleMessage(msg, proto)
}
}
}
func (p *TCPPeer) handleMessage(msg *Message) {
switch msg.CommandType() {
case CMDVersion:
version := msg.Payload.(*payload.Version)
p.version = version
fallthrough
default:
if err := p.handleProto(msg, p); err != nil {
p.discErr <- err
}
func (p *TCPPeer) handleMessage(msg *Message, proto chan protoTuple) {
switch payload := msg.Payload.(type) {
case *payload.Version:
p.version = payload
}
proto <- protoTuple{
msg: msg,
peer: p,
}
}
func (p *TCPPeer) run(proto chan protoTuple) {
var (
readErr = make(chan error, 1)
err error
)
p.wg.Add(1)
go p.readLoop(proto, readErr)
run:
for {
select {
case err = <-p.disc:
break run
case err = <-readErr:
break run
}
}
// If the peer has not started the protocol with the server
// there will be noone reading from this channel.
select {
case p.done <- err:
default:
}
close(p.closed)
p.conn.Close()
p.wg.Wait()
return
}
// Disconnect implements the Peer interface.
func (p *TCPPeer) Disconnect(reason error) {
p.disc <- reason
}

View file

@ -0,0 +1,79 @@
package network
import (
"net"
"time"
log "github.com/sirupsen/logrus"
)
// TCPTransport allows network communication over TCP.
type TCPTransport struct {
server *Server
listener net.Listener
bindAddr string
proto chan protoTuple
}
// NewTCPTransport return a new TCPTransport that will listen for
// new incoming peer connections.
func NewTCPTransport(s *Server, bindAddr string) *TCPTransport {
return &TCPTransport{
server: s,
bindAddr: bindAddr,
proto: make(chan protoTuple),
}
}
// Consumer implements the Transporter interface.
func (t *TCPTransport) Consumer() <-chan protoTuple {
return t.proto
}
// Dial implements the Transporter interface.
func (t *TCPTransport) Dial(addr string, timeout time.Duration) error {
conn, err := net.DialTimeout("tcp", addr, timeout)
if err != nil {
return err
}
go t.handleConn(conn)
return nil
}
// Accept implements the Transporter interface.
func (t *TCPTransport) Accept() {
l, err := net.Listen("tcp", t.bindAddr)
if err != nil {
log.Fatalf("TCP listen error %s", err)
return
}
t.listener = l
for {
conn, err := l.Accept()
if err != nil {
log.Warnf("TCP accept error: %s", err)
continue
}
go t.handleConn(conn)
}
}
func (t *TCPTransport) handleConn(conn net.Conn) {
p := NewTCPPeer(conn, t.proto)
t.server.register <- p
// This will block until the peer is stopped running.
p.run(t.proto)
log.Warnf("TCP released peer: %s", p.Endpoint())
}
// Close implements the Transporter interface.
func (t *TCPTransport) Close() {
t.listener.Close()
}
// Proto implements the Transporter interface.
func (t *TCPTransport) Proto() string {
return "tcp"
}

13
pkg/network/transport.go Normal file
View file

@ -0,0 +1,13 @@
package network
import "time"
// Transporter is an interface that allows us to abstract
// any form of communication between the server and its peers.
type Transporter interface {
Consumer() <-chan protoTuple
Dial(addr string, timeout time.Duration) error
Accept()
Proto() string
Close()
}