From 4023661cf18227088db937f9938b9a876e0654d0 Mon Sep 17 00:00:00 2001 From: Anthony De Meulemeester Date: Fri, 9 Mar 2018 16:55:25 +0100 Subject: [PATCH] Refactor of the Go node (#44) * added headersOp for safely processing headers * Better handling of protocol messages. * housekeeping + cleanup tests * Added more blockchain logic + unit tests * fixed unreachable error. * added structured logging for all (node) components. * added relay flag + bumped version --- "\n\n" | 0 Gopkg.lock | 26 +- Gopkg.toml | 4 + Makefile | 2 +- VERSION | 2 +- cli/server/server.go | 19 +- main.go | 15 + pkg/core/block.go | 70 ++-- pkg/core/block_test.go | 80 ++--- pkg/core/blockchain.go | 297 +++++++++-------- pkg/core/blockchain_test.go | 80 +++-- pkg/core/cache.go | 73 ++++ pkg/core/header_hash_list.go | 66 ++++ pkg/core/helper_test.go | 40 +++ pkg/core/leveldb_store.go | 1 - pkg/core/memory_store.go | 8 +- pkg/core/store.go | 33 +- pkg/core/transaction/transaction.go | 2 +- pkg/core/transaction/type.go | 28 +- pkg/core/util.go | 12 + pkg/crypto/base58.go | 1 + pkg/network/message.go | 79 ++--- pkg/network/message_test.go | 33 +- pkg/network/node.go | 219 ++++++++++++ pkg/network/node_test.go | 7 + pkg/network/payload/addr.go | 90 ----- pkg/network/payload/addr_test.go | 55 --- pkg/network/payload/address.go | 85 +++++ pkg/network/payload/address_test.go | 51 +++ pkg/network/payload/headers.go | 1 - pkg/network/payload/headers_test.go | 8 +- pkg/network/payload/inventory.go | 4 +- pkg/network/payload/version.go | 73 ++-- pkg/network/peer.go | 50 +-- pkg/network/protocol.go | 22 ++ pkg/network/rpc.go | 129 -------- pkg/network/server.go | 496 ++++++++++++---------------- pkg/network/server_test.go | 98 ++++-- pkg/network/tcp.go | 254 -------------- pkg/network/tcp_peer.go | 134 ++++++++ pkg/util/endpoint.go | 11 +- pkg/util/uint256.go | 2 +- pkg/wallet/wif.go | 2 +- 43 files changed, 1497 insertions(+), 1265 deletions(-) create mode 100644 "\n\n" create mode 100644 main.go create mode 100644 pkg/core/cache.go create mode 100644 pkg/core/header_hash_list.go create mode 100644 pkg/core/helper_test.go create mode 100644 pkg/core/util.go create mode 100644 pkg/network/node.go create mode 100644 pkg/network/node_test.go delete mode 100644 pkg/network/payload/addr.go delete mode 100644 pkg/network/payload/addr_test.go create mode 100644 pkg/network/payload/address.go create mode 100644 pkg/network/payload/address_test.go create mode 100644 pkg/network/protocol.go delete mode 100644 pkg/network/rpc.go delete mode 100644 pkg/network/tcp.go create mode 100644 pkg/network/tcp_peer.go diff --git "a/\n\n" "b/\n\n" new file mode 100644 index 000000000..e69de29bb diff --git a/Gopkg.lock b/Gopkg.lock index 748c10078..9bb4b30eb 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -19,12 +19,36 @@ revision = "346938d642f2ec3594ed81d874461961cd0faa76" version = "v1.1.0" +[[projects]] + name = "github.com/go-kit/kit" + packages = ["log"] + revision = "4dc7be5d2d12881735283bcab7352178e190fc71" + version = "v0.6.0" + +[[projects]] + name = "github.com/go-logfmt/logfmt" + packages = ["."] + revision = "390ab7935ee28ec6b286364bba9b4dd6410cb3d5" + version = "v0.3.0" + +[[projects]] + name = "github.com/go-stack/stack" + packages = ["."] + revision = "259ab82a6cad3992b4e21ff5cac294ccb06474bc" + version = "v1.7.0" + [[projects]] branch = "master" name = "github.com/golang/snappy" packages = ["."] revision = "553a641470496b2327abcac10b36396bd98e45c9" +[[projects]] + branch = "master" + name = "github.com/kr/logfmt" + packages = ["."] + revision = "b84e30acd515aadc4b783ad4ff83aff3299bdfe0" + [[projects]] name = "github.com/pmezard/go-difflib" packages = ["difflib"] @@ -98,6 +122,6 @@ [solve-meta] analyzer-name = "dep" analyzer-version = 1 - inputs-digest = "069a738aa1487766b26f9efb8103d2ce0526d43c83049cb5b792f0edf91568de" + inputs-digest = "53597073e919ad7bf52895a19f8b8526d12d666862fb1d36b4a9756e0499da5a" solver-name = "gps-cdcl" solver-version = 1 diff --git a/Gopkg.toml b/Gopkg.toml index e02ed3b24..ebf2ef9a3 100644 --- a/Gopkg.toml +++ b/Gopkg.toml @@ -51,3 +51,7 @@ [[constraint]] name = "github.com/stretchr/testify" version = "1.2.1" + +[[constraint]] + name = "github.com/go-kit/kit" + version = "0.6.0" diff --git a/Makefile b/Makefile index 7cc8e0666..a7b46c52c 100644 --- a/Makefile +++ b/Makefile @@ -19,7 +19,7 @@ push-tag: git push origin ${BRANCH} --tags run: build - ./bin/neo-go node -seed ${SEEDS} -tcp ${PORT} + ./bin/neo-go node -seed ${SEEDS} -tcp ${PORT} --relay true test: @go test ./... -cover diff --git a/VERSION b/VERSION index d21d277be..4e8f395fa 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -0.25.0 +0.26.0 diff --git a/cli/server/server.go b/cli/server/server.go index cce204982..7ccd57222 100644 --- a/cli/server/server.go +++ b/cli/server/server.go @@ -16,6 +16,7 @@ func NewCommand() cli.Command { Flags: []cli.Flag{ cli.IntFlag{Name: "tcp"}, cli.IntFlag{Name: "rpc"}, + cli.BoolFlag{Name: "relay, r"}, cli.StringFlag{Name: "seed"}, cli.BoolFlag{Name: "privnet, p"}, cli.BoolFlag{Name: "mainnet, m"}, @@ -25,12 +26,6 @@ func NewCommand() cli.Command { } func startServer(ctx *cli.Context) error { - opts := network.StartOpts{ - Seeds: parseSeeds(ctx.String("seed")), - TCP: ctx.Int("tcp"), - RPC: ctx.Int("rpc"), - } - net := network.ModePrivNet if ctx.Bool("testnet") { net = network.ModeTestNet @@ -39,8 +34,16 @@ func startServer(ctx *cli.Context) error { net = network.ModeMainNet } - s := network.NewServer(net) - s.Start(opts) + cfg := network.Config{ + UserAgent: "/NEO-GO:0.26.0/", + ListenTCP: uint16(ctx.Int("tcp")), + Seeds: parseSeeds(ctx.String("seed")), + Net: net, + Relay: ctx.Bool("relay"), + } + + s := network.NewServer(cfg) + s.Start() return nil } diff --git a/main.go b/main.go new file mode 100644 index 000000000..b02e7f247 --- /dev/null +++ b/main.go @@ -0,0 +1,15 @@ +package main + +import ( + "os" + + log "github.com/go-kit/kit/log" +) + +func main() { + logger := log.NewLogfmtLogger(log.NewSyncWriter(os.Stderr)) + logger.Log("hello", true) + + logger = log.With(logger, "module", "node") + logger.Log("foo", true) +} diff --git a/pkg/core/block.go b/pkg/core/block.go index fcdcecd59..843b62ce3 100644 --- a/pkg/core/block.go +++ b/pkg/core/block.go @@ -33,6 +33,9 @@ type BlockBase struct { _ uint8 // padding // Script used to validate the block Script *transaction.Witness + + // hash of this block, created when binary encoded. + hash util.Uint256 } // DecodeBinary implements the payload interface. @@ -68,19 +71,35 @@ func (b *BlockBase) DecodeBinary(r io.Reader) error { } b.Script = &transaction.Witness{} - return b.Script.DecodeBinary(r) + if err := b.Script.DecodeBinary(r); err != nil { + return err + } + + // Make the hash of the block here so we dont need to do this + // again. + hash, err := b.createHash() + if err != nil { + return err + } + b.hash = hash + return nil } -// Hash returns the hash of the block. +// Hash return the hash of the block. +func (b *BlockBase) Hash() util.Uint256 { + return b.hash +} + +// createHash creates the hash of the block. // When calculating the hash value of the block, instead of calculating the entire block, // only first seven fields in the block head will be calculated, which are // version, PrevBlock, MerkleRoot, timestamp, and height, the nonce, NextMiner. // Since MerkleRoot already contains the hash value of all transactions, // the modification of transaction will influence the hash value of the block. -func (b *BlockBase) Hash() (hash util.Uint256, err error) { +func (b *BlockBase) createHash() (hash util.Uint256, err error) { buf := new(bytes.Buffer) if err = b.encodeHashableFields(buf); err != nil { - return + return hash, err } // Double hash the encoded fields. @@ -92,15 +111,25 @@ func (b *BlockBase) Hash() (hash util.Uint256, err error) { // encodeHashableFields will only encode the fields used for hashing. // see Hash() for more information about the fields. func (b *BlockBase) encodeHashableFields(w io.Writer) error { - err := binary.Write(w, binary.LittleEndian, &b.Version) - err = binary.Write(w, binary.LittleEndian, &b.PrevHash) - err = binary.Write(w, binary.LittleEndian, &b.MerkleRoot) - err = binary.Write(w, binary.LittleEndian, &b.Timestamp) - err = binary.Write(w, binary.LittleEndian, &b.Index) - err = binary.Write(w, binary.LittleEndian, &b.ConsensusData) - err = binary.Write(w, binary.LittleEndian, &b.NextConsensus) - - return err + if err := binary.Write(w, binary.LittleEndian, &b.Version); err != nil { + return err + } + if err := binary.Write(w, binary.LittleEndian, &b.PrevHash); err != nil { + return err + } + if err := binary.Write(w, binary.LittleEndian, &b.MerkleRoot); err != nil { + return err + } + if err := binary.Write(w, binary.LittleEndian, &b.Timestamp); err != nil { + return err + } + if err := binary.Write(w, binary.LittleEndian, &b.Index); err != nil { + return err + } + if err := binary.Write(w, binary.LittleEndian, &b.ConsensusData); err != nil { + return err + } + return binary.Write(w, binary.LittleEndian, &b.NextConsensus) } // EncodeBinary implements the Payload interface @@ -108,21 +137,16 @@ func (b *BlockBase) EncodeBinary(w io.Writer) error { if err := b.encodeHashableFields(w); err != nil { return err } - - // padding if err := binary.Write(w, binary.LittleEndian, uint8(1)); err != nil { return err } - - // script return b.Script.EncodeBinary(w) } // Header holds the head info of a block type Header struct { BlockBase - // fixed to 0 - _ uint8 // padding + _ uint8 // padding fixed to 0 } // Verify the integrity of the header @@ -150,15 +174,12 @@ func (h *Header) EncodeBinary(w io.Writer) error { if err := h.BlockBase.EncodeBinary(w); err != nil { return err } - - // padding return binary.Write(w, binary.LittleEndian, uint8(0)) } // Block represents one block in the chain. type Block struct { BlockBase - // transaction list Transactions []*transaction.Transaction } @@ -205,11 +226,10 @@ func (b *Block) DecodeBinary(r io.Reader) error { lentx := util.ReadVarUint(r) b.Transactions = make([]*transaction.Transaction, lentx) for i := 0; i < int(lentx); i++ { - tx := &transaction.Transaction{} - if err := tx.DecodeBinary(r); err != nil { + b.Transactions[i] = &transaction.Transaction{} + if err := b.Transactions[i].DecodeBinary(r); err != nil { return err } - b.Transactions[i] = tx } return nil diff --git a/pkg/core/block_test.go b/pkg/core/block_test.go index 0fb6d1098..200b00589 100644 --- a/pkg/core/block_test.go +++ b/pkg/core/block_test.go @@ -8,6 +8,7 @@ import ( "github.com/CityOfZion/neo-go/pkg/core/transaction" "github.com/CityOfZion/neo-go/pkg/util" + "github.com/stretchr/testify/assert" ) func TestDecodeBlock(t *testing.T) { @@ -29,63 +30,24 @@ func TestDecodeBlock(t *testing.T) { if err := block.DecodeBinary(bytes.NewReader(rawBlockBytes)); err != nil { t.Fatal(err) } - if block.Index != uint32(rawBlockIndex) { - t.Fatalf("expected the index to the block to be %d got %d", rawBlockIndex, block.Index) - } - if block.Timestamp != uint32(rawBlockTimestamp) { - t.Fatalf("expected timestamp to be %d got %d", rawBlockTimestamp, block.Timestamp) - } - if block.ConsensusData != uint64(rawBlockConsensusData) { - t.Fatalf("expected consensus data to be %d got %d", rawBlockConsensusData, block.ConsensusData) - } - if block.PrevHash.String() != rawBlockPrevHash { - t.Fatalf("expected prev block hash to be %s got %s", rawBlockPrevHash, block.PrevHash) - } - hash, err := block.Hash() - if err != nil { - t.Fatal(err) - } - if hash.String() != rawBlockHash { - t.Fatalf("expected hash of the block to be %s got %s", rawBlockHash, hash) - } -} - -func newBlockBase() BlockBase { - return BlockBase{ - Version: 0, - PrevHash: sha256.Sum256([]byte("a")), - MerkleRoot: sha256.Sum256([]byte("b")), - Timestamp: 999, - Index: 1, - ConsensusData: 1111, - NextConsensus: util.Uint160{}, - Script: &transaction.Witness{ - VerificationScript: []byte{0x0}, - InvocationScript: []byte{0x1}, - }, - } + assert.Equal(t, uint32(rawBlockIndex), block.Index) + assert.Equal(t, uint32(rawBlockTimestamp), block.Timestamp) + assert.Equal(t, uint64(rawBlockConsensusData), block.ConsensusData) + assert.Equal(t, rawBlockPrevHash, block.PrevHash.String()) + assert.Equal(t, rawBlockHash, block.Hash().String()) } func TestHashBlockEqualsHashHeader(t *testing.T) { - base := newBlockBase() - b := &Block{BlockBase: base} - head := &Header{BlockBase: base} - - bhash, _ := b.Hash() - headhash, _ := head.Hash() - if bhash != headhash { - t.Fatalf("expected both hashes to be equal %s and %s", bhash, headhash) - } + block := newBlock(0) + assert.Equal(t, block.Hash(), block.Header().Hash()) } func TestBlockVerify(t *testing.T) { - block := &Block{ - BlockBase: newBlockBase(), - Transactions: []*transaction.Transaction{ - {Type: transaction.MinerType}, - {Type: transaction.IssueType}, - }, - } + block := newBlock( + 0, + newTX(transaction.MinerType), + newTX(transaction.IssueType), + ) if !block.Verify(false) { t.Fatal("block should be verified") @@ -109,3 +71,19 @@ func TestBlockVerify(t *testing.T) { t.Fatal("block should not by verified") } } + +func newBlockBase() BlockBase { + return BlockBase{ + Version: 0, + PrevHash: sha256.Sum256([]byte("a")), + MerkleRoot: sha256.Sum256([]byte("b")), + Timestamp: 999, + Index: 1, + ConsensusData: 1111, + NextConsensus: util.Uint160{}, + Script: &transaction.Witness{ + VerificationScript: []byte{0x0}, + InvocationScript: []byte{0x1}, + }, + } +} diff --git a/pkg/core/blockchain.go b/pkg/core/blockchain.go index 39a04e3c0..7495ceb41 100644 --- a/pkg/core/blockchain.go +++ b/pkg/core/blockchain.go @@ -3,17 +3,19 @@ package core import ( "bytes" "encoding/binary" - "log" - "sync" + "fmt" + "os" + "sync/atomic" "time" "github.com/CityOfZion/neo-go/pkg/util" + log "github.com/go-kit/kit/log" ) // tuning parameters const ( secondsPerBlock = 15 - writeHdrBatchCnt = 2000 + headerBatchCount = 2000 ) var ( @@ -22,188 +24,217 @@ var ( // Blockchain holds the chain. type Blockchain struct { - logger *log.Logger + logger log.Logger // Any object that satisfies the BlockchainStorer interface. Store - // current index of the heighest block - currentBlockHeight uint32 + // Current index/height of the highest block. + // Read access should always be called by BlockHeight(). + // Writes access should only happen in persist(). + blockHeight uint32 - // number of headers stored + // Number of headers stored. storedHeaderCount uint32 - mtx sync.RWMutex + blockCache *Cache - // index of headers hashes - headerIndex []util.Uint256 + startHash util.Uint256 + + // Only for operating on the headerList. + headersOp chan headersOpFunc + headersOpDone chan struct{} } -// NewBlockchain returns a pointer to a Blockchain. -func NewBlockchain(s Store, l *log.Logger, startHash util.Uint256) *Blockchain { +type headersOpFunc func(headerList *HeaderHashList) + +// NewBlockchain creates a new Blockchain object. +func NewBlockchain(s Store, startHash util.Uint256) *Blockchain { + logger := log.NewLogfmtLogger(os.Stderr) + logger = log.With(logger, "component", "blockchain") + bc := &Blockchain{ - logger: l, - Store: s, + logger: logger, + Store: s, + headersOp: make(chan headersOpFunc), + headersOpDone: make(chan struct{}), + startHash: startHash, + blockCache: NewCache(), } - - // Starthash is 0, so we will create the genesis block. - if startHash.Equals(util.Uint256{}) { - bc.logger.Fatal("genesis block not yet implemented") - } - - bc.headerIndex = []util.Uint256{startHash} + go bc.run() + bc.init() return bc } -// genesisBlock creates the genesis block for the chain. -// hash of the genesis block: -// d42561e3d30e15be6400b6df2f328e02d2bf6354c41dce433bc57687c82144bf -func (bc *Blockchain) genesisBlock() *Block { - timestamp := uint32(time.Date(2016, 7, 15, 15, 8, 21, 0, time.UTC).Unix()) +func (bc *Blockchain) init() { + // for the initial header, for now + bc.storedHeaderCount = 1 +} - // TODO: for testing I will hardcode the merkleroot. - // This let's me focus on the bringing all the puzzle pieces - // togheter much faster. - // For more information about the genesis block: - // https://neotracker.io/block/height/0 - mr, _ := util.Uint256DecodeString("803ff4abe3ea6533bcc0be574efa02f83ae8fdc651c879056b0d9be336c01bf4") - - return &Block{ - BlockBase: BlockBase{ - Version: 0, - PrevHash: util.Uint256{}, - MerkleRoot: mr, - Timestamp: timestamp, - Index: 0, - ConsensusData: 2083236893, // nioctib ^^ - NextConsensus: util.Uint160{}, // todo - }, +func (bc *Blockchain) run() { + headerList := NewHeaderHashList(bc.startHash) + for { + select { + case op := <-bc.headersOp: + op(headerList) + bc.headersOpDone <- struct{}{} + } } } -// AddBlock (to be continued after headers is finished..) func (bc *Blockchain) AddBlock(block *Block) error { - // TODO: caching - headerLen := len(bc.headerIndex) + if !bc.blockCache.Has(block.Hash()) { + bc.blockCache.Add(block.Hash(), block) + } + headerLen := int(bc.HeaderHeight() + 1) if int(block.Index-1) >= headerLen { return nil } - if int(block.Index) == headerLen { // todo: if (VerifyBlocks && !block.Verify()) return false; } - - if int(block.Index) < headerLen { - return nil - } - - return nil + return bc.AddHeaders(block.Header()) } -func (bc *Blockchain) addHeader(header *Header) error { - return bc.AddHeaders(header) +func (bc *Blockchain) AddHeaders(headers ...*Header) (err error) { + var ( + start = time.Now() + batch = Batch{} + ) + + bc.headersOp <- func(headerList *HeaderHashList) { + for _, h := range headers { + if int(h.Index-1) >= headerList.Len() { + err = fmt.Errorf( + "height of block higher then current header height %d > %d\n", + h.Index, headerList.Len(), + ) + return + } + if int(h.Index) < headerList.Len() { + continue + } + if !h.Verify() { + err = fmt.Errorf("header %v is invalid", h) + return + } + if err = bc.processHeader(h, batch, headerList); err != nil { + return + } + } + + // TODO: Implement caching strategy. + if len(batch) > 0 { + if err = bc.writeBatch(batch); err != nil { + return + } + bc.logger.Log( + "msg", "done processing headers", + "index", headerList.Len()-1, + "took", time.Since(start).Seconds(), + ) + } + } + <-bc.headersOpDone + return err } -// AddHeaders processes the given headers. -func (bc *Blockchain) AddHeaders(headers ...*Header) error { - start := time.Now() - - bc.mtx.Lock() - defer bc.mtx.Unlock() - - batch := Batch{} - for _, h := range headers { - if int(h.Index-1) >= len(bc.headerIndex) { - bc.logger.Printf("height of block higher then header index %d %d\n", - h.Index, len(bc.headerIndex)) - break - } - if int(h.Index) < len(bc.headerIndex) { - continue - } - if !h.Verify() { - bc.logger.Printf("header %v is invalid", h) - break - } - if err := bc.processHeader(h, batch); err != nil { - return err - } - } - - // TODO: Implement caching strategy. - if len(batch) > 0 { - // Write all batches. - if err := bc.writeBatch(batch); err != nil { - return err - } - - bc.logger.Printf("done processing headers up to index %d took %f Seconds", - bc.HeaderHeight(), time.Since(start).Seconds()) - } - - return nil -} - -// processHeader processes 1 header. -func (bc *Blockchain) processHeader(h *Header, batch Batch) error { - hash, err := h.Hash() - if err != nil { - return err - } - bc.headerIndex = append(bc.headerIndex, hash) - - for int(h.Index)-writeHdrBatchCnt >= int(bc.storedHeaderCount) { - // hdrsToWrite = bc.headerIndex[bc.storedHeaderCount : bc.storedHeaderCount+writeHdrBatchCnt] - - // NOTE: from original #c to be implemented: - // - // w.Write(header_index.Skip((int)stored_header_count).Take(2000).ToArray()); - // w.Flush(); - // batch.Put(SliceBuilder.Begin(DataEntryPrefix.IX_HeaderHashList).Add(stored_header_count), ms.ToArray()); - - bc.storedHeaderCount += writeHdrBatchCnt - } +// processHeader processes the given header. Note that this is only thread safe +// if executed in headers operation. +func (bc *Blockchain) processHeader(h *Header, batch Batch, headerList *HeaderHashList) error { + headerList.Add(h.Hash()) buf := new(bytes.Buffer) + for int(h.Index)-headerBatchCount >= int(bc.storedHeaderCount) { + if err := headerList.Write(buf, int(bc.storedHeaderCount), headerBatchCount); err != nil { + return err + } + key := makeEntryPrefixInt(preIXHeaderHashList, int(bc.storedHeaderCount)) + batch[&key] = buf.Bytes() + bc.storedHeaderCount += headerBatchCount + buf.Reset() + } + + buf.Reset() if err := h.EncodeBinary(buf); err != nil { return err } - preBlock := preDataBlock.add(hash.BytesReverse()) - batch[&preBlock] = buf.Bytes() - preHeader := preSYSCurrentHeader.toSlice() - batch[&preHeader] = hashAndIndexToBytes(hash, h.Index) + key := makeEntryPrefix(preDataBlock, h.Hash().BytesReverse()) + batch[&key] = buf.Bytes() + key = preSYSCurrentHeader.bytes() + batch[&key] = hashAndIndexToBytes(h.Hash(), h.Index) return nil } -// CurrentBlockHash return the lastest hash in the header index. -func (bc *Blockchain) CurrentBlockHash() (hash util.Uint256) { - if len(bc.headerIndex) == 0 { - return - } - if len(bc.headerIndex) < int(bc.currentBlockHeight) { - return - } +func (bc *Blockchain) persistBlock(block *Block) error { + bc.blockHeight = block.Index + return nil +} - return bc.headerIndex[bc.currentBlockHeight] +func (bc *Blockchain) persist() (err error) { + var ( + persisted = 0 + lenCache = bc.blockCache.Len() + ) + + for lenCache > persisted { + if bc.HeaderHeight()+1 <= bc.BlockHeight() { + break + } + bc.headersOp <- func(headerList *HeaderHashList) { + hash := headerList.Get(int(bc.BlockHeight() + 1)) + if block, ok := bc.blockCache.GetBlock(hash); ok { + if err = bc.persistBlock(block); err != nil { + return + } + bc.blockCache.Delete(hash) + persisted++ + } else { + bc.logger.Log( + "msg", "block not found in cache", + "hash", block.Hash(), + ) + } + } + <-bc.headersOpDone + } + return +} + +// CurrentBlockHash returns the heighest processed block hash. +func (bc *Blockchain) CurrentBlockHash() (hash util.Uint256) { + bc.headersOp <- func(headerList *HeaderHashList) { + hash = headerList.Get(int(bc.BlockHeight())) + } + <-bc.headersOpDone + return } // CurrentHeaderHash returns the hash of the latest known header. func (bc *Blockchain) CurrentHeaderHash() (hash util.Uint256) { - return bc.headerIndex[len(bc.headerIndex)-1] + bc.headersOp <- func(headerList *HeaderHashList) { + hash = headerList.Last() + } + <-bc.headersOpDone + return } -// BlockHeight return the height/index of the latest block this node has. +// BlockHeight returns the height/index of the highest block. func (bc *Blockchain) BlockHeight() uint32 { - return bc.currentBlockHeight + return atomic.LoadUint32(&bc.blockHeight) } -// HeaderHeight returns the current index of the headers. -func (bc *Blockchain) HeaderHeight() uint32 { - return uint32(len(bc.headerIndex)) - 1 +// HeaderHeight returns the index/height of the highest header. +func (bc *Blockchain) HeaderHeight() (n uint32) { + bc.headersOp <- func(headerList *HeaderHashList) { + n = uint32(headerList.Len() - 1) + } + <-bc.headersOpDone + return } func hashAndIndexToBytes(h util.Uint256, index uint32) []byte { diff --git a/pkg/core/blockchain_test.go b/pkg/core/blockchain_test.go index 6cd69bdcb..fab77b225 100644 --- a/pkg/core/blockchain_test.go +++ b/pkg/core/blockchain_test.go @@ -1,48 +1,70 @@ package core import ( - "log" - "os" "testing" - "github.com/CityOfZion/neo-go/pkg/core/transaction" "github.com/CityOfZion/neo-go/pkg/util" + "github.com/stretchr/testify/assert" ) func TestNewBlockchain(t *testing.T) { startHash, _ := util.Uint256DecodeString("996e37358dc369912041f966f8c5d8d3a8255ba5dcbd3447f8a82b55db869099") - bc := NewBlockchain(nil, nil, startHash) + bc := NewBlockchain(nil, startHash) - want := uint32(0) - if have := bc.BlockHeight(); want != have { - t.Fatalf("expected %d got %d", want, have) - } - if have := bc.HeaderHeight(); want != have { - t.Fatalf("expected %d got %d", want, have) - } - if have := bc.storedHeaderCount; want != have { - t.Fatalf("expected %d got %d", want, have) - } - if !bc.CurrentBlockHash().Equals(startHash) { - t.Fatalf("expected current block hash to be %d got %s", startHash, bc.CurrentBlockHash()) - } + assert.Equal(t, uint32(0), bc.BlockHeight()) + assert.Equal(t, uint32(0), bc.HeaderHeight()) + assert.Equal(t, uint32(1), bc.storedHeaderCount) + assert.Equal(t, startHash, bc.startHash) } func TestAddHeaders(t *testing.T) { - startHash, _ := util.Uint256DecodeString("996e37358dc369912041f966f8c5d8d3a8255ba5dcbd3447f8a82b55db869099") - bc := NewBlockchain(NewMemoryStore(), log.New(os.Stdout, "", 0), startHash) - - h1 := &Header{BlockBase: BlockBase{Version: 0, Index: 1, Script: &transaction.Witness{}}} - h2 := &Header{BlockBase: BlockBase{Version: 0, Index: 2, Script: &transaction.Witness{}}} - h3 := &Header{BlockBase: BlockBase{Version: 0, Index: 3, Script: &transaction.Witness{}}} + bc := newTestBC() + h1 := newBlock(1).Header() + h2 := newBlock(2).Header() + h3 := newBlock(3).Header() if err := bc.AddHeaders(h1, h2, h3); err != nil { t.Fatal(err) } - if want, have := h3.Index, bc.HeaderHeight(); want != have { - t.Fatalf("expected header height of %d got %d", want, have) - } - if want, have := uint32(0), bc.storedHeaderCount; want != have { - t.Fatalf("expected stored header count to be %d got %d", want, have) - } + + assert.Equal(t, 0, bc.blockCache.Len()) + assert.Equal(t, h3.Index, bc.HeaderHeight()) + assert.Equal(t, uint32(1), bc.storedHeaderCount) + assert.Equal(t, uint32(0), bc.BlockHeight()) + assert.Equal(t, h3.Hash(), bc.CurrentHeaderHash()) +} + +func TestAddBlock(t *testing.T) { + bc := newTestBC() + blocks := []*Block{ + newBlock(1), + newBlock(2), + newBlock(3), + } + + for i := 0; i < len(blocks); i++ { + if err := bc.AddBlock(blocks[i]); err != nil { + t.Fatal(err) + } + } + + lastBlock := blocks[len(blocks)-1] + assert.Equal(t, 3, bc.blockCache.Len()) + assert.Equal(t, lastBlock.Index, bc.HeaderHeight()) + assert.Equal(t, lastBlock.Hash(), bc.CurrentHeaderHash()) + assert.Equal(t, uint32(1), bc.storedHeaderCount) + + if err := bc.persist(); err != nil { + t.Fatal(err) + } + + assert.Equal(t, lastBlock.Index, bc.BlockHeight()) + assert.Equal(t, lastBlock.Hash(), bc.CurrentHeaderHash()) + assert.Equal(t, 0, bc.blockCache.Len()) +} + +func newTestBC() *Blockchain { + startHash, _ := util.Uint256DecodeString("a") + bc := NewBlockchain(NewMemoryStore(), startHash) + return bc } diff --git a/pkg/core/cache.go b/pkg/core/cache.go new file mode 100644 index 000000000..3cd51cf34 --- /dev/null +++ b/pkg/core/cache.go @@ -0,0 +1,73 @@ +package core + +import ( + "sync" + + "github.com/CityOfZion/neo-go/pkg/util" +) + +// Cache is data structure with fixed type key of Uint256, but has a +// generic value. Used for block and header cash types. +type Cache struct { + lock sync.RWMutex + m map[util.Uint256]interface{} +} + +// NewCache returns a ready to use Cache object. +func NewCache() *Cache { + return &Cache{ + m: make(map[util.Uint256]interface{}), + } +} + +// GetBlock will return a Block type from the cache. +func (c *Cache) GetBlock(h util.Uint256) (block *Block, ok bool) { + c.lock.RLock() + defer c.lock.RUnlock() + return c.getBlock(h) +} + +func (c *Cache) getBlock(h util.Uint256) (block *Block, ok bool) { + if v, b := c.m[h]; b { + block, ok = v.(*Block) + return + } + return +} + +// Add adds the given hash along with its value to the cache. +func (c *Cache) Add(h util.Uint256, v interface{}) { + c.lock.Lock() + defer c.lock.Unlock() + c.add(h, v) +} + +func (c *Cache) add(h util.Uint256, v interface{}) { + c.m[h] = v +} + +func (c *Cache) has(h util.Uint256) bool { + _, ok := c.m[h] + return ok +} + +// Hash returns whether the cach contains the given hash. +func (c *Cache) Has(h util.Uint256) bool { + c.lock.Lock() + defer c.lock.Unlock() + return c.has(h) +} + +// Len return the number of items present in the cache. +func (c *Cache) Len() int { + c.lock.RLock() + defer c.lock.RUnlock() + return len(c.m) +} + +// Delete removes the item out of the cache. +func (c *Cache) Delete(h util.Uint256) { + c.lock.Lock() + defer c.lock.Unlock() + delete(c.m, h) +} diff --git a/pkg/core/header_hash_list.go b/pkg/core/header_hash_list.go new file mode 100644 index 000000000..f47169b84 --- /dev/null +++ b/pkg/core/header_hash_list.go @@ -0,0 +1,66 @@ +package core + +import ( + "encoding/binary" + "io" + + "github.com/CityOfZion/neo-go/pkg/util" +) + +// A HeaderHashList represents a list of header hashes. +type HeaderHashList struct { + hashes []util.Uint256 +} + +// NewHeaderHashList return a new pointer to a HeaderHashList. +func NewHeaderHashList(hashes ...util.Uint256) *HeaderHashList { + return &HeaderHashList{ + hashes: hashes, + } +} + +// Add appends the given hash to the list of hashes. +func (l *HeaderHashList) Add(h util.Uint256) { + l.hashes = append(l.hashes, h) +} + +// Len return the length of the underlying hashes slice. +func (l *HeaderHashList) Len() int { + return len(l.hashes) +} + +// Get returns the hash by the given index. +func (l *HeaderHashList) Get(i int) util.Uint256 { + if l.Len() < i { + return util.Uint256{} + } + return l.hashes[i] +} + +// Last return the last hash in the HeaderHashList. +func (l *HeaderHashList) Last() util.Uint256 { + return l.hashes[l.Len()-1] +} + +// Slice return a subslice of the underlying hashes. +// Subsliced from start to end. +// Example: +// headers := headerList.Slice(0, 2000) +func (l *HeaderHashList) Slice(start, end int) []util.Uint256 { + return l.hashes[start:end] +} + +// WriteTo will write n underlying hashes to the given io.Writer +// starting from start. +func (l *HeaderHashList) Write(w io.Writer, start, n int) error { + if err := util.WriteVarUint(w, uint64(n)); err != nil { + return err + } + hashes := l.Slice(start, start+n) + for _, hash := range hashes { + if err := binary.Write(w, binary.LittleEndian, hash); err != nil { + return err + } + } + return nil +} diff --git a/pkg/core/helper_test.go b/pkg/core/helper_test.go new file mode 100644 index 000000000..9933fffde --- /dev/null +++ b/pkg/core/helper_test.go @@ -0,0 +1,40 @@ +package core + +import ( + "crypto/sha256" + "time" + + "github.com/CityOfZion/neo-go/pkg/core/transaction" + "github.com/CityOfZion/neo-go/pkg/util" +) + +func newBlock(index uint32, txs ...*transaction.Transaction) *Block { + b := &Block{ + BlockBase: BlockBase{ + Version: 0, + PrevHash: sha256.Sum256([]byte("a")), + MerkleRoot: sha256.Sum256([]byte("b")), + Timestamp: uint32(time.Now().UTC().Unix()), + Index: index, + ConsensusData: 1111, + NextConsensus: util.Uint160{}, + Script: &transaction.Witness{ + VerificationScript: []byte{0x0}, + InvocationScript: []byte{0x1}, + }, + }, + Transactions: txs, + } + hash, err := b.createHash() + if err != nil { + panic(err) + } + b.hash = hash + return b +} + +func newTX(t transaction.TXType) *transaction.Transaction { + return &transaction.Transaction{ + Type: t, + } +} diff --git a/pkg/core/leveldb_store.go b/pkg/core/leveldb_store.go index 037bc3525..072fc7135 100644 --- a/pkg/core/leveldb_store.go +++ b/pkg/core/leveldb_store.go @@ -21,6 +21,5 @@ func (s *LevelDBStore) writeBatch(batch Batch) error { for k, v := range batch { b.Put(*k, v) } - return s.db.Write(b, nil) } diff --git a/pkg/core/memory_store.go b/pkg/core/memory_store.go index ef9e4c126..5a47601e6 100644 --- a/pkg/core/memory_store.go +++ b/pkg/core/memory_store.go @@ -1,6 +1,7 @@ package core -// MemoryStore is an in memory implementation of a BlockChainStorer. +// MemoryStore is an in memory implementation of a BlockChainStorer +// that should only be used for testing. type MemoryStore struct { } @@ -14,5 +15,10 @@ func (m *MemoryStore) write(key, value []byte) error { } func (m *MemoryStore) writeBatch(batch Batch) error { + for k, v := range batch { + if err := m.write(*k, v); err != nil { + return err + } + } return nil } diff --git a/pkg/core/store.go b/pkg/core/store.go index 0e18c75f3..fd0ca6b98 100644 --- a/pkg/core/store.go +++ b/pkg/core/store.go @@ -1,17 +1,13 @@ package core +import ( + "bytes" + "encoding/binary" +) + type dataEntry uint8 -func (e dataEntry) add(b []byte) []byte { - dest := make([]byte, len(b)+1) - dest[0] = byte(e) - for i := 1; i < len(b); i++ { - dest[i] = b[i] - } - return dest -} - -func (e dataEntry) toSlice() []byte { +func (e dataEntry) bytes() []byte { return []byte{byte(e)} } @@ -32,6 +28,21 @@ const ( preSYSVersion dataEntry = 0xf0 ) +func makeEntryPrefixInt(e dataEntry, n int) []byte { + buf := new(bytes.Buffer) + binary.Write(buf, binary.LittleEndian, n) + return makeEntryPrefix(e, buf.Bytes()) +} + +func makeEntryPrefix(e dataEntry, b []byte) []byte { + dest := make([]byte, len(b)+1) + dest[0] = byte(e) + for i := 1; i < len(b); i++ { + dest[i] = b[i] + } + return dest +} + // Store is anything that can persist and retrieve the blockchain. type Store interface { write(k, v []byte) error @@ -39,5 +50,5 @@ type Store interface { } // Batch is a data type used to store data for later batch operations -// by any Store. +// that can be used by any Store interface implementation. type Batch map[*[]byte][]byte diff --git a/pkg/core/transaction/transaction.go b/pkg/core/transaction/transaction.go index deeab0527..92e4c9a38 100644 --- a/pkg/core/transaction/transaction.go +++ b/pkg/core/transaction/transaction.go @@ -10,7 +10,7 @@ import ( // Transaction is a process recorded in the NEO blockchain. type Transaction struct { // The type of the transaction. - Type TransactionType + Type TXType // The trading version which is currently 0. Version uint8 diff --git a/pkg/core/transaction/type.go b/pkg/core/transaction/type.go index fbf95a608..9b4cfb9e8 100644 --- a/pkg/core/transaction/type.go +++ b/pkg/core/transaction/type.go @@ -1,26 +1,26 @@ package transaction -// TransactionType is the type of a transaction. -type TransactionType uint8 +// TXType is the type of a transaction. +type TXType uint8 // All processes in NEO system are recorded in transactions. // There are several types of transactions. const ( - MinerType TransactionType = 0x00 - IssueType TransactionType = 0x01 - ClaimType TransactionType = 0x02 - EnrollmentType TransactionType = 0x20 - VotingType TransactionType = 0x24 - RegisterType TransactionType = 0x40 - ContractType TransactionType = 0x80 - StateType TransactionType = 0x90 - AgencyType TransactionType = 0xb0 - PublishType TransactionType = 0xd0 - InvocationType TransactionType = 0xd1 + MinerType TXType = 0x00 + IssueType TXType = 0x01 + ClaimType TXType = 0x02 + EnrollmentType TXType = 0x20 + VotingType TXType = 0x24 + RegisterType TXType = 0x40 + ContractType TXType = 0x80 + StateType TXType = 0x90 + AgencyType TXType = 0xb0 + PublishType TXType = 0xd0 + InvocationType TXType = 0xd1 ) // String implements the stringer interface. -func (t TransactionType) String() string { +func (t TXType) String() string { switch t { case MinerType: return "miner transaction" diff --git a/pkg/core/util.go b/pkg/core/util.go new file mode 100644 index 000000000..701ac34ac --- /dev/null +++ b/pkg/core/util.go @@ -0,0 +1,12 @@ +package core + +import "github.com/CityOfZion/neo-go/pkg/util" + +// Utilities for quick bootstrapping blockchains. Normally we should +// create the genisis block. For now (to speed up development) we will add +// The hashes manually. + +func GenesisHashPrivNet() util.Uint256 { + hash, _ := util.Uint256DecodeString("996e37358dc369912041f966f8c5d8d3a8255ba5dcbd3447f8a82b55db869099") + return hash +} diff --git a/pkg/crypto/base58.go b/pkg/crypto/base58.go index 07c4aeb47..012670ce9 100644 --- a/pkg/crypto/base58.go +++ b/pkg/crypto/base58.go @@ -79,6 +79,7 @@ func Base58Encode(bytes []byte) string { return encoded } +// Base58CheckDecode decodes the given string. func Base58CheckDecode(s string) (b []byte, err error) { b, err = Base58Decode(s) if err != nil { diff --git a/pkg/network/message.go b/pkg/network/message.go index ae484f58a..f6b59a676 100644 --- a/pkg/network/message.go +++ b/pkg/network/message.go @@ -61,26 +61,28 @@ type Message struct { Payload payload.Payload } -type commandType string +// CommandType represents the type of a message command. +type CommandType string // valid commands used to send between nodes. const ( - cmdVersion commandType = "version" - cmdVerack = "verack" - cmdGetAddr = "getaddr" - cmdAddr = "addr" - cmdGetHeaders = "getheaders" - cmdHeaders = "headers" - cmdGetBlocks = "getblocks" - cmdInv = "inv" - cmdGetData = "getdata" - cmdBlock = "block" - cmdTX = "tx" - cmdConsensus = "consensus" - cmdUnknown = "unknown" + CMDVersion CommandType = "version" + CMDVerack CommandType = "verack" + CMDGetAddr CommandType = "getaddr" + CMDAddr CommandType = "addr" + CMDGetHeaders CommandType = "getheaders" + CMDHeaders CommandType = "headers" + CMDGetBlocks CommandType = "getblocks" + CMDInv CommandType = "inv" + CMDGetData CommandType = "getdata" + CMDBlock CommandType = "block" + CMDTX CommandType = "tx" + CMDConsensus CommandType = "consensus" + CMDUnknown CommandType = "unknown" ) -func newMessage(magic NetMode, cmd commandType, p payload.Payload) *Message { +// NewMessage returns a new message with the given payload. +func NewMessage(magic NetMode, cmd CommandType, p payload.Payload) *Message { var ( size uint32 checksum []byte @@ -106,36 +108,36 @@ func newMessage(magic NetMode, cmd commandType, p payload.Payload) *Message { } } -// Converts the 12 byte command slice to a commandType. -func (m *Message) commandType() commandType { +// CommandType converts the 12 byte command slice to a CommandType. +func (m *Message) CommandType() CommandType { cmd := cmdByteArrayToString(m.Command) switch cmd { case "version": - return cmdVersion + return CMDVersion case "verack": - return cmdVerack + return CMDVerack case "getaddr": - return cmdGetAddr + return CMDGetAddr case "addr": - return cmdAddr + return CMDAddr case "getheaders": - return cmdGetHeaders + return CMDGetHeaders case "headers": - return cmdHeaders + return CMDHeaders case "getblocks": - return cmdGetBlocks + return CMDGetBlocks case "inv": - return cmdInv + return CMDInv case "getdata": - return cmdGetData + return CMDGetData case "block": - return cmdBlock + return CMDBlock case "tx": - return cmdTX + return CMDTX case "consensus": - return cmdConsensus + return CMDConsensus default: - return cmdUnknown + return CMDUnknown } } @@ -185,36 +187,35 @@ func (m *Message) decodePayload(r io.Reader) error { return errChecksumMismatch } - //r = bytes.NewReader(buf) r = buf var p payload.Payload - switch m.commandType() { - case cmdVersion: + switch m.CommandType() { + case CMDVersion: p = &payload.Version{} if err := p.DecodeBinary(r); err != nil { return err } - case cmdInv: + case CMDInv: p = &payload.Inventory{} if err := p.DecodeBinary(r); err != nil { return err } - case cmdAddr: + case CMDAddr: p = &payload.AddressList{} if err := p.DecodeBinary(r); err != nil { return err } - case cmdBlock: + case CMDBlock: p = &core.Block{} if err := p.DecodeBinary(r); err != nil { return err } - case cmdGetHeaders: + case CMDGetHeaders: p = &payload.GetBlocks{} if err := p.DecodeBinary(r); err != nil { return err } - case cmdHeaders: + case CMDHeaders: p = &payload.Headers{} if err := p.DecodeBinary(r); err != nil { return err @@ -242,7 +243,7 @@ func (m *Message) encode(w io.Writer) error { // convert a command (string) to a byte slice filled with 0 bytes till // size 12. -func cmdToByteArray(cmd commandType) [cmdSize]byte { +func cmdToByteArray(cmd CommandType) [cmdSize]byte { cmdLen := len(cmd) if cmdLen > cmdSize { panic("exceeded command max length of size 12") diff --git a/pkg/network/message_test.go b/pkg/network/message_test.go index 6b1c5b596..6378e7795 100644 --- a/pkg/network/message_test.go +++ b/pkg/network/message_test.go @@ -2,39 +2,33 @@ package network import ( "bytes" - "reflect" "testing" "github.com/CityOfZion/neo-go/pkg/network/payload" + "github.com/stretchr/testify/assert" ) func TestMessageEncodeDecode(t *testing.T) { - m := newMessage(ModeTestNet, cmdVersion, nil) + m := NewMessage(ModeTestNet, CMDVersion, nil) buf := &bytes.Buffer{} if err := m.encode(buf); err != nil { t.Error(err) } - - if n := len(buf.Bytes()); n < minMessageSize { - t.Fatalf("message should be at least %d bytes got %d", minMessageSize, n) - } - if n := len(buf.Bytes()); n > minMessageSize { - t.Fatalf("message without a payload should be exact %d bytes got %d", minMessageSize, n) - } + assert.Equal(t, len(buf.Bytes()), minMessageSize) md := &Message{} if err := md.decode(buf); err != nil { t.Error(err) } - if !reflect.DeepEqual(m, md) { - t.Errorf("both messages should be equal: %v != %v", m, md) - } + assert.Equal(t, m, md) } func TestMessageEncodeDecodeWithVersion(t *testing.T) { - p := payload.NewVersion(12227, 2000, "/neo:2.6.0/", 0, true) - m := newMessage(ModeTestNet, cmdVersion, p) + var ( + p = payload.NewVersion(12227, 2000, "/neo:2.6.0/", 0, true) + m = NewMessage(ModeTestNet, CMDVersion, p) + ) buf := new(bytes.Buffer) if err := m.encode(buf); err != nil { @@ -45,15 +39,14 @@ func TestMessageEncodeDecodeWithVersion(t *testing.T) { if err := mDecode.decode(buf); err != nil { t.Fatal(err) } - - if !reflect.DeepEqual(m, mDecode) { - t.Fatalf("expected both messages to be equal %v and %v", m, mDecode) - } + assert.Equal(t, m, mDecode) } func TestMessageInvalidChecksum(t *testing.T) { - p := payload.NewVersion(1111, 3000, "/NEO:2.6.0/", 0, true) - m := newMessage(ModeTestNet, cmdVersion, p) + var ( + p = payload.NewVersion(1111, 3000, "/NEO:2.6.0/", 0, true) + m = NewMessage(ModeTestNet, CMDVersion, p) + ) m.Checksum = 1337 buf := new(bytes.Buffer) diff --git a/pkg/network/node.go b/pkg/network/node.go new file mode 100644 index 000000000..88a0c3cb8 --- /dev/null +++ b/pkg/network/node.go @@ -0,0 +1,219 @@ +package network + +import ( + "errors" + "fmt" + "os" + "time" + + "github.com/CityOfZion/neo-go/pkg/core" + "github.com/CityOfZion/neo-go/pkg/network/payload" + "github.com/CityOfZion/neo-go/pkg/util" + log "github.com/go-kit/kit/log" +) + +const ( + protoVersion = 0 +) + +var protoTickInterval = 5 * time.Second + +// Node represents the local node. +type Node struct { + // Config fields may not be modified while the server is running. + Config + + logger log.Logger + server *Server + services uint64 + bc *core.Blockchain + protoIn chan messageTuple +} + +// messageTuple respresents a tuple that holds the message being +// send along with its peer. +type messageTuple struct { + peer Peer + msg *Message +} + +func newNode(s *Server, cfg Config) *Node { + var startHash util.Uint256 + if cfg.Net == ModePrivNet { + startHash = core.GenesisHashPrivNet() + } + + bc := core.NewBlockchain( + core.NewMemoryStore(), + startHash, + ) + + logger := log.NewLogfmtLogger(os.Stderr) + logger = log.With(logger, "component", "node") + + n := &Node{ + Config: cfg, + protoIn: make(chan messageTuple), + server: s, + bc: bc, + logger: logger, + } + go n.handleMessages() + + return n +} + +func (n *Node) version() *payload.Version { + return payload.NewVersion(n.server.id, n.ListenTCP, n.UserAgent, 1, n.Relay) +} + +func (n *Node) startProtocol(peer Peer) { + ticker := time.NewTicker(protoTickInterval).C + + for { + select { + case <-ticker: + // Try to sync with the peer if his block height is higher then ours. + if peer.Version().StartHeight > n.bc.HeaderHeight() { + n.askMoreHeaders(peer) + } + // Only ask for more peers if the server has the capacity for it. + if n.server.hasCapacity() { + msg := NewMessage(n.Net, CMDGetAddr, nil) + peer.Send(msg) + } + case <-peer.Done(): + return + } + } +} + +// When a peer sends out his version we reply with verack after validating +// the version. +func (n *Node) handleVersionCmd(version *payload.Version, peer Peer) error { + msg := NewMessage(n.Net, CMDVerack, nil) + peer.Send(msg) + return nil +} + +// handleInvCmd handles the forwarded inventory received from the peer. +// We will use the getdata message to get more details about the received +// inventory. +// note: if the server has Relay on false, inventory messages are not received. +func (n *Node) handleInvCmd(inv *payload.Inventory, peer Peer) error { + if !inv.Type.Valid() { + return fmt.Errorf("invalid inventory type received: %s", inv.Type) + } + if len(inv.Hashes) == 0 { + return errors.New("inventory has no hashes") + } + payload := payload.NewInventory(inv.Type, inv.Hashes) + peer.Send(NewMessage(n.Net, CMDGetData, payload)) + return nil +} + +// handleBlockCmd processes the received block received from its peer. +func (n *Node) handleBlockCmd(block *core.Block, peer Peer) error { + n.logger.Log( + "event", "block received", + "index", block.Index, + "hash", block.Hash(), + "tx", len(block.Transactions), + ) + + return n.bc.AddBlock(block) +} + +// After a node sends out the getaddr message its receives a list of known peers +// in the network. handleAddrCmd processes that payload. +func (n *Node) handleAddrCmd(addressList *payload.AddressList, peer Peer) error { + addrs := make([]string, len(addressList.Addrs)) + for i := 0; i < len(addrs); i++ { + addrs[i] = addressList.Addrs[i].Address.String() + } + n.server.connectToPeers(addrs...) + return nil +} + +// The handleHeadersCmd will process the received headers from its peer. +// We call this in a routine cause we may block Peers Send() for to long. +func (n *Node) handleHeadersCmd(headers *payload.Headers, peer Peer) error { + go func(headers []*core.Header) { + if err := n.bc.AddHeaders(headers...); err != nil { + n.logger.Log("msg", "failed processing headers", "err", err) + return + } + // The peer will respond with a maximum of 2000 headers in one batch. + // We will ask one more batch here if needed. Eventually we will get synced + // due to the startProtocol routine that will ask headers every protoTick. + if n.bc.HeaderHeight() < peer.Version().StartHeight { + n.askMoreHeaders(peer) + } + }(headers.Hdrs) + + return nil +} + +// askMoreHeaders will send a getheaders message to the peer. +func (n *Node) askMoreHeaders(p Peer) { + start := []util.Uint256{n.bc.CurrentHeaderHash()} + payload := payload.NewGetBlocks(start, util.Uint256{}) + p.Send(NewMessage(n.Net, CMDGetHeaders, payload)) +} + +// blockhain implements the Noder interface. +func (n *Node) blockchain() *core.Blockchain { return n.bc } + +// handleProto implements the protoHandler interface. +func (n *Node) handleProto(msg *Message, p Peer) { + n.protoIn <- messageTuple{ + msg: msg, + peer: p, + } +} + +func (n *Node) handleMessages() { + for { + t := <-n.protoIn + + var ( + msg = t.msg + p = t.peer + err error + ) + + switch msg.CommandType() { + case CMDVersion: + version := msg.Payload.(*payload.Version) + err = n.handleVersionCmd(version, p) + case CMDAddr: + addressList := msg.Payload.(*payload.AddressList) + err = n.handleAddrCmd(addressList, p) + case CMDInv: + inventory := msg.Payload.(*payload.Inventory) + err = n.handleInvCmd(inventory, p) + case CMDBlock: + block := msg.Payload.(*core.Block) + err = n.handleBlockCmd(block, p) + case CMDHeaders: + headers := msg.Payload.(*payload.Headers) + err = n.handleHeadersCmd(headers, p) + case CMDVerack: + // Only start the protocol if we got the version and verack + // received. + if p.Version() != nil { + go n.startProtocol(p) + } + case CMDUnknown: + err = errors.New("received non-protocol messgae") + } + + if err != nil { + n.logger.Log( + "msg", "failed processing message", + "command", msg.CommandType, + "err", err, + ) + } + } +} diff --git a/pkg/network/node_test.go b/pkg/network/node_test.go new file mode 100644 index 000000000..37b217781 --- /dev/null +++ b/pkg/network/node_test.go @@ -0,0 +1,7 @@ +package network + +import "testing" + +func TestHandleVersion(t *testing.T) { + +} diff --git a/pkg/network/payload/addr.go b/pkg/network/payload/addr.go deleted file mode 100644 index 3feec11c7..000000000 --- a/pkg/network/payload/addr.go +++ /dev/null @@ -1,90 +0,0 @@ -package payload - -import ( - "encoding/binary" - "io" - - "github.com/CityOfZion/neo-go/pkg/util" -) - -// AddrWithTime payload -type AddrWithTime struct { - // Timestamp the node connected to the network. - Timestamp uint32 - Services uint64 - Addr util.Endpoint -} - -// NewAddrWithTime return a pointer to AddrWithTime. -func NewAddrWithTime(addr util.Endpoint) *AddrWithTime { - return &AddrWithTime{ - Timestamp: 1337, - Services: 1, - Addr: addr, - } -} - -// Size implements the payload interface. -func (p *AddrWithTime) Size() uint32 { - return 30 -} - -// DecodeBinary implements the Payload interface. -func (p *AddrWithTime) DecodeBinary(r io.Reader) error { - err := binary.Read(r, binary.LittleEndian, &p.Timestamp) - err = binary.Read(r, binary.LittleEndian, &p.Services) - err = binary.Read(r, binary.BigEndian, &p.Addr.IP) - err = binary.Read(r, binary.BigEndian, &p.Addr.Port) - - return err -} - -// EncodeBinary implements the Payload interface. -func (p *AddrWithTime) EncodeBinary(w io.Writer) error { - err := binary.Write(w, binary.LittleEndian, p.Timestamp) - err = binary.Write(w, binary.LittleEndian, p.Services) - err = binary.Write(w, binary.BigEndian, p.Addr.IP) - err = binary.Write(w, binary.BigEndian, p.Addr.Port) - - return err -} - -// AddressList is a list with AddrWithTime. -type AddressList struct { - Addrs []*AddrWithTime -} - -// DecodeBinary implements the Payload interface. -func (p *AddressList) DecodeBinary(r io.Reader) error { - listLen := util.ReadVarUint(r) - - p.Addrs = make([]*AddrWithTime, listLen) - for i := 0; i < int(listLen); i++ { - addr := &AddrWithTime{} - if err := addr.DecodeBinary(r); err != nil { - return err - } - p.Addrs[i] = addr - } - - return nil -} - -// EncodeBinary implements the Payload interface. -func (p *AddressList) EncodeBinary(w io.Writer) error { - // Write the length of the slice - util.WriteVarUint(w, uint64(len(p.Addrs))) - - for _, addr := range p.Addrs { - if err := addr.EncodeBinary(w); err != nil { - return err - } - } - - return nil -} - -// Size implements the Payloader interface. -func (p *AddressList) Size() uint32 { - return uint32(len(p.Addrs) * 30) -} diff --git a/pkg/network/payload/addr_test.go b/pkg/network/payload/addr_test.go deleted file mode 100644 index 656d5bfbd..000000000 --- a/pkg/network/payload/addr_test.go +++ /dev/null @@ -1,55 +0,0 @@ -package payload - -import ( - "bytes" - "fmt" - "reflect" - "testing" - - "github.com/CityOfZion/neo-go/pkg/util" -) - -func TestEncodeDecodeAddr(t *testing.T) { - e, err := util.EndpointFromString("127.0.0.1:2000") - if err != nil { - t.Fatal(err) - } - - addr := NewAddrWithTime(e) - buf := new(bytes.Buffer) - if err := addr.EncodeBinary(buf); err != nil { - t.Fatal(err) - } - - addrDecode := &AddrWithTime{} - if err := addrDecode.DecodeBinary(buf); err != nil { - t.Fatal(err) - } - - if !reflect.DeepEqual(addr, addrDecode) { - t.Fatalf("expected both addr payloads to be equal: %v and %v", addr, addrDecode) - } -} - -func TestEncodeDecodeAddressList(t *testing.T) { - var lenList uint8 = 4 - addrList := &AddressList{make([]*AddrWithTime, lenList)} - for i := 0; i < int(lenList); i++ { - e, _ := util.EndpointFromString(fmt.Sprintf("127.0.0.1:200%d", i)) - addrList.Addrs[i] = NewAddrWithTime(e) - } - - buf := new(bytes.Buffer) - if err := addrList.EncodeBinary(buf); err != nil { - t.Fatal(err) - } - - addrListDecode := &AddressList{} - if err := addrListDecode.DecodeBinary(buf); err != nil { - t.Fatal(err) - } - - if !reflect.DeepEqual(addrList, addrListDecode) { - t.Fatalf("expected both address list payloads to be equal: %v and %v", addrList, addrListDecode) - } -} diff --git a/pkg/network/payload/address.go b/pkg/network/payload/address.go new file mode 100644 index 000000000..da237f8d6 --- /dev/null +++ b/pkg/network/payload/address.go @@ -0,0 +1,85 @@ +package payload + +import ( + "encoding/binary" + "io" + "time" + + "github.com/CityOfZion/neo-go/pkg/util" +) + +// AddressAndTime payload. +type AddressAndTime struct { + Timestamp uint32 + Services uint64 + Address util.Endpoint +} + +// NewAddressAndTime creates a new AddressAndTime object. +func NewAddressAndTime(e util.Endpoint, t time.Time) *AddressAndTime { + return &AddressAndTime{ + Timestamp: uint32(t.UTC().Unix()), + Services: 1, + Address: e, + } +} + +// DecodeBinary implements the Payload interface. +func (p *AddressAndTime) DecodeBinary(r io.Reader) error { + if err := binary.Read(r, binary.LittleEndian, &p.Timestamp); err != nil { + return err + } + if err := binary.Read(r, binary.LittleEndian, &p.Services); err != nil { + return err + } + if err := binary.Read(r, binary.BigEndian, &p.Address.IP); err != nil { + return err + } + return binary.Read(r, binary.BigEndian, &p.Address.Port) +} + +// EncodeBinary implements the Payload interface. +func (p *AddressAndTime) EncodeBinary(w io.Writer) error { + if err := binary.Write(w, binary.LittleEndian, p.Timestamp); err != nil { + return err + } + if err := binary.Write(w, binary.LittleEndian, p.Services); err != nil { + return err + } + if err := binary.Write(w, binary.BigEndian, p.Address.IP); err != nil { + return err + } + return binary.Write(w, binary.BigEndian, p.Address.Port) +} + +// AddressList is a list with AddrAndTime. +type AddressList struct { + Addrs []*AddressAndTime +} + +// DecodeBinary implements the Payload interface. +func (p *AddressList) DecodeBinary(r io.Reader) error { + listLen := util.ReadVarUint(r) + + p.Addrs = make([]*AddressAndTime, listLen) + for i := 0; i < int(listLen); i++ { + p.Addrs[i] = &AddressAndTime{} + if err := p.Addrs[i].DecodeBinary(r); err != nil { + return err + } + } + return nil +} + +// EncodeBinary implements the Payload interface. +func (p *AddressList) EncodeBinary(w io.Writer) error { + if err := util.WriteVarUint(w, uint64(len(p.Addrs))); err != nil { + return err + } + for _, addr := range p.Addrs { + if err := addr.EncodeBinary(w); err != nil { + return err + } + } + return nil +} diff --git a/pkg/network/payload/address_test.go b/pkg/network/payload/address_test.go new file mode 100644 index 000000000..c2b95780a --- /dev/null +++ b/pkg/network/payload/address_test.go @@ -0,0 +1,51 @@ +package payload + +import ( + "bytes" + "fmt" + "testing" + "time" + + "github.com/CityOfZion/neo-go/pkg/util" + "github.com/stretchr/testify/assert" +) + +func TestEncodeDecodeAddress(t *testing.T) { + var ( + e = util.NewEndpoint("127.0.0.1:2000") + addr = NewAddressAndTime(e, time.Now()) + buf = new(bytes.Buffer) + ) + + if err := addr.EncodeBinary(buf); err != nil { + t.Fatal(err) + } + + addrDecode := &AddressAndTime{} + if err := addrDecode.DecodeBinary(buf); err != nil { + t.Fatal(err) + } + + assert.Equal(t, addr, addrDecode) +} + +func TestEncodeDecodeAddressList(t *testing.T) { + var lenList uint8 = 4 + addrList := &AddressList{make([]*AddressAndTime, lenList)} + for i := 0; i < int(lenList); i++ { + e := util.NewEndpoint(fmt.Sprintf("127.0.0.1:200%d", i)) + addrList.Addrs[i] = NewAddressAndTime(e, time.Now()) + } + + buf := new(bytes.Buffer) + if err := addrList.EncodeBinary(buf); err != nil { + t.Fatal(err) + } + + addrListDecode := &AddressList{} + if err := addrListDecode.DecodeBinary(buf); err != nil { + t.Fatal(err) + } + + assert.Equal(t, addrList, addrListDecode) +} diff --git a/pkg/network/payload/headers.go b/pkg/network/payload/headers.go index 373b149da..6350239ba 100644 --- a/pkg/network/payload/headers.go +++ b/pkg/network/payload/headers.go @@ -37,6 +37,5 @@ func (p *Headers) EncodeBinary(w io.Writer) error { return err } } - return nil } diff --git a/pkg/network/payload/headers_test.go b/pkg/network/payload/headers_test.go index a2848f984..43d6169f3 100644 --- a/pkg/network/payload/headers_test.go +++ b/pkg/network/payload/headers_test.go @@ -2,11 +2,11 @@ package payload import ( "bytes" - "reflect" "testing" "github.com/CityOfZion/neo-go/pkg/core" "github.com/CityOfZion/neo-go/pkg/core/transaction" + "github.com/stretchr/testify/assert" ) func TestHeadersEncodeDecode(t *testing.T) { @@ -50,7 +50,9 @@ func TestHeadersEncodeDecode(t *testing.T) { t.Fatal(err) } - if !reflect.DeepEqual(headers, headersDecode) { - t.Fatalf("expected both header payload to be equal %+v and %+v", headers, headersDecode) + for i := 0; i < len(headers.Hdrs); i++ { + assert.Equal(t, headers.Hdrs[i].Version, headersDecode.Hdrs[i].Version) + assert.Equal(t, headers.Hdrs[i].Index, headersDecode.Hdrs[i].Index) + assert.Equal(t, headers.Hdrs[i].Script, headersDecode.Hdrs[i].Script) } } diff --git a/pkg/network/payload/inventory.go b/pkg/network/payload/inventory.go index f2c27aa44..502c255f8 100644 --- a/pkg/network/payload/inventory.go +++ b/pkg/network/payload/inventory.go @@ -35,8 +35,8 @@ func (i InventoryType) Valid() bool { // List of valid InventoryTypes. const ( BlockType InventoryType = 0x01 // 1 - TXType = 0x02 // 2 - ConsensusType = 0xe0 // 224 + TXType InventoryType = 0x02 // 2 + ConsensusType InventoryType = 0xe0 // 224 ) // Inventory payload diff --git a/pkg/network/payload/version.go b/pkg/network/payload/version.go index 5f8d4d0c8..42d69f451 100644 --- a/pkg/network/payload/version.go +++ b/pkg/network/payload/version.go @@ -44,36 +44,63 @@ func NewVersion(id uint32, p uint16, ua string, h uint32, r bool) *Version { // DecodeBinary implements the Payload interface. func (p *Version) DecodeBinary(r io.Reader) error { - err := binary.Read(r, binary.LittleEndian, &p.Version) - err = binary.Read(r, binary.LittleEndian, &p.Services) - err = binary.Read(r, binary.LittleEndian, &p.Timestamp) - err = binary.Read(r, binary.LittleEndian, &p.Port) - err = binary.Read(r, binary.LittleEndian, &p.Nonce) + if err := binary.Read(r, binary.LittleEndian, &p.Version); err != nil { + return err + } + if err := binary.Read(r, binary.LittleEndian, &p.Services); err != nil { + return err + } + if err := binary.Read(r, binary.LittleEndian, &p.Timestamp); err != nil { + return err + } + if err := binary.Read(r, binary.LittleEndian, &p.Port); err != nil { + return err + } + if err := binary.Read(r, binary.LittleEndian, &p.Nonce); err != nil { + return err + } var lenUA uint8 - err = binary.Read(r, binary.LittleEndian, &lenUA) + if err := binary.Read(r, binary.LittleEndian, &lenUA); err != nil { + return err + } p.UserAgent = make([]byte, lenUA) - err = binary.Read(r, binary.LittleEndian, &p.UserAgent) - - err = binary.Read(r, binary.LittleEndian, &p.StartHeight) - err = binary.Read(r, binary.LittleEndian, &p.Relay) - - return err + if err := binary.Read(r, binary.LittleEndian, &p.UserAgent); err != nil { + return err + } + if err := binary.Read(r, binary.LittleEndian, &p.StartHeight); err != nil { + return err + } + return binary.Read(r, binary.LittleEndian, &p.Relay) } // EncodeBinary implements the Payload interface. func (p *Version) EncodeBinary(w io.Writer) error { - err := binary.Write(w, binary.LittleEndian, p.Version) - err = binary.Write(w, binary.LittleEndian, p.Services) - err = binary.Write(w, binary.LittleEndian, p.Timestamp) - err = binary.Write(w, binary.LittleEndian, p.Port) - err = binary.Write(w, binary.LittleEndian, p.Nonce) - err = binary.Write(w, binary.LittleEndian, uint8(len(p.UserAgent))) - err = binary.Write(w, binary.LittleEndian, p.UserAgent) - err = binary.Write(w, binary.LittleEndian, p.StartHeight) - err = binary.Write(w, binary.LittleEndian, p.Relay) - - return err + if err := binary.Write(w, binary.LittleEndian, p.Version); err != nil { + return err + } + if err := binary.Write(w, binary.LittleEndian, p.Services); err != nil { + return err + } + if err := binary.Write(w, binary.LittleEndian, p.Timestamp); err != nil { + return err + } + if err := binary.Write(w, binary.LittleEndian, p.Port); err != nil { + return err + } + if err := binary.Write(w, binary.LittleEndian, p.Nonce); err != nil { + return err + } + if err := binary.Write(w, binary.LittleEndian, uint8(len(p.UserAgent))); err != nil { + return err + } + if err := binary.Write(w, binary.LittleEndian, p.UserAgent); err != nil { + return err + } + if err := binary.Write(w, binary.LittleEndian, p.StartHeight); err != nil { + return err + } + return binary.Write(w, binary.LittleEndian, p.Relay) } // Size implements the payloader interface. diff --git a/pkg/network/peer.go b/pkg/network/peer.go index f492d565c..f99668ed8 100644 --- a/pkg/network/peer.go +++ b/pkg/network/peer.go @@ -5,48 +5,12 @@ import ( "github.com/CityOfZion/neo-go/pkg/util" ) -// Peer is the local representation of a remote node. It's an interface that may -// be backed by any concrete transport: local, HTTP, tcp. +// A Peer is the local representation of a remote peer. +// It's an interface that may be backed by any concrete +// transport. type Peer interface { - id() uint32 - addr() util.Endpoint - disconnect() - Send(*Message) error - version() *payload.Version + Version() *payload.Version + Endpoint() util.Endpoint + Send(*Message) + Done() chan struct{} } - -// LocalPeer is the simplest kind of peer, mapped to a server in the -// same process-space. -type LocalPeer struct { - s *Server - nonce uint32 - endpoint util.Endpoint - pVersion *payload.Version -} - -// NewLocalPeer return a LocalPeer. -func NewLocalPeer(s *Server) *LocalPeer { - e, _ := util.EndpointFromString("1.1.1.1:1111") - return &LocalPeer{endpoint: e, s: s} -} - -func (p *LocalPeer) Send(msg *Message) error { - switch msg.commandType() { - case cmdVersion: - version := msg.Payload.(*payload.Version) - return p.s.handleVersionCmd(version, p) - case cmdGetAddr: - return p.s.handleGetaddrCmd(msg, p) - default: - return nil - } -} - -// Version implements the Peer interface. -func (p *LocalPeer) version() *payload.Version { - return p.pVersion -} - -func (p *LocalPeer) id() uint32 { return p.nonce } -func (p *LocalPeer) addr() util.Endpoint { return p.endpoint } -func (p *LocalPeer) disconnect() {} diff --git a/pkg/network/protocol.go b/pkg/network/protocol.go new file mode 100644 index 000000000..f81d5b003 --- /dev/null +++ b/pkg/network/protocol.go @@ -0,0 +1,22 @@ +package network + +import ( + "github.com/CityOfZion/neo-go/pkg/core" + "github.com/CityOfZion/neo-go/pkg/network/payload" +) + +// A ProtoHandler is an interface that abstract the implementation +// of the NEO protocol. +type ProtoHandler interface { + version() *payload.Version + handleProto(*Message, Peer) +} + +type protoHandleFunc func(*Message, Peer) + +// Noder is anything that implements the NEO protocol +// and can return the Blockchain object. +type Noder interface { + ProtoHandler + blockchain() *core.Blockchain +} diff --git a/pkg/network/rpc.go b/pkg/network/rpc.go deleted file mode 100644 index c1156fa65..000000000 --- a/pkg/network/rpc.go +++ /dev/null @@ -1,129 +0,0 @@ -package network - -import ( - "encoding/json" - "fmt" - "net/http" -) - -const ( - rpcPortMainNet = 20332 - rpcPortTestNet = 10332 - rpcVersion = "2.0" - - // error response messages - methodNotFound = "Method not found" - parseError = "Parse error" -) - -// Each NEO node has a set of optional APIs for accessing blockchain -// data and making things easier for development of blockchain apps. -// APIs are provided via JSON-RPC , comm at bottom layer is with http/https protocol. - -// listenHTTP creates an ingress bridge from the outside world to the passed -// server, by installing handlers for all the necessary RPCs to the passed mux. -func listenHTTP(s *Server, port int) { - api := &API{s} - p := fmt.Sprintf(":%d", port) - s.logger.Printf("serving RPC on %d", port) - s.logger.Printf("%s", http.ListenAndServe(p, api)) -} - -// API serves JSON-RPC. -type API struct { - s *Server -} - -func (s *API) ServeHTTP(w http.ResponseWriter, r *http.Request) { - // Official nodes respond a parse error if the method is not POST. - // Instead of returning a decent response for this, let's do the same. - if r.Method != "POST" { - writeError(w, 0, 0, parseError) - } - - var req Request - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - writeError(w, 0, 0, parseError) - return - } - defer r.Body.Close() - - if req.Version != rpcVersion { - writeJSON(w, http.StatusBadRequest, nil) - return - } - - switch req.Method { - case "getconnectioncount": - if err := s.getConnectionCount(w, &req); err != nil { - writeError(w, 0, 0, parseError) - return - } - case "getblockcount": - case "getbestblockhash": - default: - writeError(w, 0, 0, methodNotFound) - } -} - -// This is an Example on how we could handle incomming RPC requests. -func (s *API) getConnectionCount(w http.ResponseWriter, req *Request) error { - count := s.s.peerCount() - - resp := ConnectionCountResponse{ - Version: rpcVersion, - Result: count, - ID: 1, - } - - return writeJSON(w, http.StatusOK, resp) -} - -// writeError returns a JSON error with given parameters. All error HTTP -// status codes are 200. According to the official API. -func writeError(w http.ResponseWriter, id, code int, msg string) error { - resp := RequestError{ - Version: rpcVersion, - ID: id, - Error: Error{ - Code: code, - Message: msg, - }, - } - - return writeJSON(w, http.StatusOK, resp) -} - -func writeJSON(w http.ResponseWriter, status int, v interface{}) error { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(status) - return json.NewEncoder(w).Encode(v) -} - -// Request is an object received through JSON-RPC from the client. -type Request struct { - Version string `json:"jsonrpc"` - Method string `json:"method"` - Params []string `json:"params"` - ID int `json:"id"` -} - -// ConnectionCountResponse .. -type ConnectionCountResponse struct { - Version string `json:"jsonrpc"` - Result int `json:"result"` - ID int `json:"id"` -} - -// RequestError .. -type RequestError struct { - Version string `json:"jsonrpc"` - ID int `json:"id"` - Error Error `json:"error"` -} - -// Error holds information about an RCP error. -type Error struct { - Code int `json:"code"` - Message string `json:"message"` -} diff --git a/pkg/network/server.go b/pkg/network/server.go index af9986d1d..8c8367fbe 100644 --- a/pkg/network/server.go +++ b/pkg/network/server.go @@ -1,354 +1,282 @@ package network import ( - "context" - "errors" "fmt" - "log" "net" "os" - "sync" + "text/tabwriter" "time" - "github.com/CityOfZion/neo-go/pkg/core" - "github.com/CityOfZion/neo-go/pkg/network/payload" "github.com/CityOfZion/neo-go/pkg/util" + log "github.com/go-kit/kit/log" ) const ( // node version version = "2.6.0" + // official ports according to the protocol. portMainNet = 10333 portTestNet = 20333 maxPeers = 50 ) -type messageTuple struct { - peer Peer - msg *Message +var dialTimeout = 4 * time.Second + +// Config holds the server configuration. +type Config struct { + // MaxPeers it the maximum numbers of peers that can + // be connected to the server. + MaxPeers int + + // The user agent of the server. + UserAgent string + + // The listen address of the TCP server. + ListenTCP uint16 + + // The listen address of the RPC server. + ListenRPC uint16 + + // The network mode this server will operate on. + // ModePrivNet docker private network. + // ModeTestNet NEO test network. + // ModeMainNet NEO main network. + Net NetMode + + // Relay determins whether the server is forwarding its inventory. + Relay bool + + // Seeds are a list of initial nodes used to establish connectivity. + Seeds []string + + // Maximum duration a single dial may take. + DialTimeout time.Duration } -// Server is the representation of a full working NEO TCP node. +// Server manages all incoming peer connections. type Server struct { - logger *log.Logger - // id of the server + // Config fields may not be modified while the server is running. + Config + + // Proto is just about anything that can handle the NEO protocol. + // In production enviroments the ProtoHandler is mostly the local node. + proto ProtoHandler + + // Unique id of this server. id uint32 - // the port the TCP listener is listening on. - port uint16 - // userAgent of the server. - userAgent string - // The "magic" mode the server is currently running on. - // This can either be 0x00746e41 or 0x74746e41 for main or test net. - // Or 56753 to work with the docker privnet. - net NetMode - // map that holds all connected peers to this server. - peers map[Peer]bool - // channel for handling new registerd peers. - register chan Peer - // channel for safely removing and disconnecting peers. - unregister chan Peer - // channel for coordinating messages. - message chan messageTuple - // channel used to gracefull shutdown the server. - quit chan struct{} - // Whether this server will receive and forward messages. - relay bool - // TCP listener of the server + + logger log.Logger listener net.Listener - // channel for safely responding the number of current connected peers. - peerCountCh chan peerCount - // a list of hashes that - knownHashes protectedHashmap - // The blockchain. - bc *core.Blockchain + + register chan Peer + unregister chan Peer + + badAddrOp chan func(map[string]bool) + badAddrOpDone chan struct{} + + peerOp chan func(map[Peer]bool) + peerOpDone chan struct{} + + quit chan struct{} } -// TODO: Maybe util is a better place for such data types. -type protectedHashmap struct { - *sync.RWMutex - hashes map[util.Uint256]bool -} - -func (m protectedHashmap) add(h util.Uint256) bool { - m.Lock() - defer m.Unlock() - - if _, ok := m.hashes[h]; !ok { - m.hashes[h] = true - return true +// NewServer returns a new Server object created from the +// given config. +func NewServer(cfg Config) *Server { + if cfg.MaxPeers == 0 { + cfg.MaxPeers = maxPeers } - return false -} - -func (m protectedHashmap) remove(h util.Uint256) bool { - m.Lock() - defer m.Unlock() - - if _, ok := m.hashes[h]; ok { - delete(m.hashes, h) - return true + if cfg.Net == 0 { + cfg.Net = ModeTestNet } - return false -} - -func (m protectedHashmap) has(h util.Uint256) bool { - m.RLock() - defer m.RUnlock() - - _, ok := m.hashes[h] - - return ok -} - -// NewServer returns a pointer to a new server. -func NewServer(net NetMode) *Server { - logger := log.New(os.Stdout, "[NEO SERVER] :: ", 0) - - if net != ModeTestNet && net != ModeMainNet && net != ModePrivNet { - logger.Fatalf("invalid network mode %d", net) + if cfg.DialTimeout == 0 { + cfg.DialTimeout = dialTimeout } - // For now I will hard code a genesis block of the docker privnet container. - startHash, _ := util.Uint256DecodeString("996e37358dc369912041f966f8c5d8d3a8255ba5dcbd3447f8a82b55db869099") + logger := log.NewLogfmtLogger(os.Stderr) + logger = log.With(logger, "component", "server") s := &Server{ - id: util.RandUint32(1111111, 9999999), - userAgent: fmt.Sprintf("/NEO:%s/", version), - logger: logger, - peers: make(map[Peer]bool), - register: make(chan Peer), - unregister: make(chan Peer), - message: make(chan messageTuple), - relay: true, // currently relay is not handled. - net: net, - quit: make(chan struct{}), - peerCountCh: make(chan peerCount), - bc: core.NewBlockchain(core.NewMemoryStore(), logger, startHash), + Config: cfg, + logger: logger, + id: util.RandUint32(1000000, 9999999), + quit: make(chan struct{}, 1), + register: make(chan Peer), + unregister: make(chan Peer), + badAddrOp: make(chan func(map[string]bool)), + badAddrOpDone: make(chan struct{}), + peerOp: make(chan func(map[Peer]bool)), + peerOpDone: make(chan struct{}), } + s.proto = newNode(s, cfg) + return s } -// Start run's the server. -// TODO: server should be initialized with a config. -func (s *Server) Start(opts StartOpts) { - s.port = uint16(opts.TCP) - - fmt.Println(logo()) - fmt.Println(string(s.userAgent)) - fmt.Println("") - s.logger.Printf("NET: %s - TCP: %d - RELAY: %v - ID: %d", - s.net, int(s.port), s.relay, s.id) - - go listenTCP(s, opts.TCP) - - if opts.RPC > 0 { - go listenHTTP(s, opts.RPC) - } - - if len(opts.Seeds) > 0 { - connectToSeeds(s, opts.Seeds) - } - - s.loop() -} - -// Stop the server, attemping a gracefull shutdown. -func (s *Server) Stop() { s.quit <- struct{}{} } - -// shutdown the server, disconnecting all peers. -func (s *Server) shutdown() { - s.logger.Println("attemping a quitefull shutdown.") - s.listener.Close() - - // disconnect and remove all connected peers. - for peer := range s.peers { - peer.disconnect() - } -} - -func (s *Server) loop() { - for { - select { - // When a new connection is been established, (by this server or remote node) - // its peer will be received on this channel. - // Any peer registration must happen via this channel. - case peer := <-s.register: - if len(s.peers) < maxPeers { - s.logger.Printf("peer registered from address %s", peer.addr()) - s.peers[peer] = true - - if err := s.handlePeerConnected(peer); err != nil { - s.logger.Printf("failed handling peer connection: %s", err) - peer.disconnect() - } - } - - // unregister safely deletes a peer. For disconnecting peers use the - // disconnect() method on the peer, it will call unregister and terminates its routines. - case peer := <-s.unregister: - if _, ok := s.peers[peer]; ok { - delete(s.peers, peer) - s.logger.Printf("peer %s disconnected", peer.addr()) - } - - case t := <-s.peerCountCh: - t.count <- len(s.peers) - - case <-s.quit: - s.shutdown() - } - } -} - -// When a new peer is connected we send our version. -// No further communication should be made before both sides has received -// the versions of eachother. -func (s *Server) handlePeerConnected(p Peer) error { - // TODO: get the blockheight of this server once core implemented this. - payload := payload.NewVersion(s.id, s.port, s.userAgent, s.bc.HeaderHeight(), s.relay) - msg := newMessage(s.net, cmdVersion, payload) - return p.Send(msg) -} - -func (s *Server) handleVersionCmd(version *payload.Version, p Peer) error { - if s.id == version.Nonce { - return errors.New("identical nonce") - } - if p.addr().Port != version.Port { - return fmt.Errorf("port mismatch: %d and %d", version.Port, p.addr().Port) - } - - return p.Send( - newMessage(s.net, cmdVerack, nil), - ) -} - -func (s *Server) handleGetaddrCmd(msg *Message, p Peer) error { - return nil -} - -// The node can broadcast the object information it owns by this message. -// The message can be sent automatically or can be used to answer getbloks messages. -func (s *Server) handleInvCmd(inv *payload.Inventory, p Peer) error { - if !inv.Type.Valid() { - return fmt.Errorf("invalid inventory type %s", inv.Type) - } - if len(inv.Hashes) == 0 { - return errors.New("inventory should have at least 1 hash got 0") - } - - // todo: only grab the hashes that we dont know. - - payload := payload.NewInventory(inv.Type, inv.Hashes) - resp := newMessage(s.net, cmdGetData, payload) - - return p.Send(resp) -} - -// handleBlockCmd processes the received block. -func (s *Server) handleBlockCmd(block *core.Block, p Peer) error { - hash, err := block.Hash() +func (s *Server) createListener() error { + ln, err := net.Listen("tcp", fmt.Sprintf(":%d", s.ListenTCP)) if err != nil { return err } - - s.logger.Printf("new block: index %d hash %s", block.Index, hash) - + s.listener = ln return nil } -// After receiving the getaddr message, the node returns an addr message as response -// and provides information about the known nodes on the network. -func (s *Server) handleAddrCmd(addrList *payload.AddressList, p Peer) error { - for _, addr := range addrList.Addrs { - if !s.peerAlreadyConnected(addr.Addr) { - // TODO: this is not transport abstracted. - go connectToRemoteNode(s, addr.Addr.String()) +func (s *Server) listenTCP() { + for { + conn, err := s.listener.Accept() + if err != nil { + s.logger.Log("msg", "conn read error", "err", err) + break + } + go s.setupConnection(conn) + } + s.Quit() +} + +func (s *Server) setupConnection(conn net.Conn) { + if !s.hasCapacity() { + s.logger.Log("msg", "server reached maximum capacity") + return + } + + p := NewTCPPeer(conn, s.proto.handleProto) + s.register <- p + if err := p.run(); err != nil { + s.unregister <- p + } +} + +func (s *Server) connectToPeers(addrs ...string) { + for _, addr := range addrs { + if s.hasCapacity() && s.canConnectWith(addr) { + go func(addr string) { + conn, err := net.DialTimeout("tcp", addr, s.DialTimeout) + if err != nil { + s.badAddrOp <- func(badAddrs map[string]bool) { + badAddrs[addr] = true + } + <-s.badAddrOpDone + return + } + go s.setupConnection(conn) + }(addr) } } - return nil } -// Handle the headers received from the remote after we asked for headers with the -// "getheaders" message. -func (s *Server) handleHeadersCmd(headers *payload.Headers, p Peer) error { - // Set a deadline for adding headers? - go func(ctx context.Context, headers []*core.Header) { - if err := s.bc.AddHeaders(headers...); err != nil { - s.logger.Printf("failed to add headers: %s", err) - return - } - - // Ask more headers if we are not in sync with the peer. - if s.bc.HeaderHeight() < p.version().StartHeight { - if err := s.askMoreHeaders(p); err != nil { - s.logger.Printf("getheaders RPC failed: %s", err) - return +func (s *Server) canConnectWith(addr string) bool { + canConnect := true + s.peerOp <- func(peers map[Peer]bool) { + for peer := range peers { + if peer.Endpoint().String() == addr { + canConnect = false + break } } - }(context.TODO(), headers.Hdrs) + } + <-s.peerOpDone + if !canConnect { + return false + } - return nil + s.badAddrOp <- func(badAddrs map[string]bool) { + _, ok := badAddrs[addr] + canConnect = !ok + } + <-s.badAddrOpDone + return canConnect } -// Ask the peer for more headers We use the current block hash as start. -func (s *Server) askMoreHeaders(p Peer) error { - start := []util.Uint256{s.bc.CurrentHeaderHash()} - payload := payload.NewGetBlocks(start, util.Uint256{}) - msg := newMessage(s.net, cmdGetHeaders, payload) - - return p.Send(msg) +func (s *Server) hasCapacity() bool { + return s.PeerCount() != s.MaxPeers } -// check if the addr is already connected to the server. -func (s *Server) peerAlreadyConnected(addr net.Addr) bool { - // TODO: Dont try to connect with ourselfs. - for peer := range s.peers { - if peer.addr().String() == addr.String() { - return true +func (s *Server) sendVersion(peer Peer) { + peer.Send(NewMessage(s.Net, CMDVersion, s.proto.version())) +} + +func (s *Server) run() { + var ( + ticker = time.NewTicker(30 * time.Second).C + peers = make(map[Peer]bool) + badAddrs = make(map[string]bool) + ) + + for { + select { + case op := <-s.badAddrOp: + op(badAddrs) + s.badAddrOpDone <- struct{}{} + case op := <-s.peerOp: + op(peers) + s.peerOpDone <- struct{}{} + case p := <-s.register: + peers[p] = true + // When a new peer connection is established, we send + // out our version immediately. + s.sendVersion(p) + s.logger.Log("event", "peer connected", "endpoint", p.Endpoint()) + case p := <-s.unregister: + delete(peers, p) + s.logger.Log("event", "peer disconnected", "endpoint", p.Endpoint()) + case <-ticker: + s.printState() + case <-s.quit: + return } } - return false } -// TODO: Quit this routine if the peer is disconnected. -func (s *Server) startProtocol(p Peer) { - if s.bc.HeaderHeight() < p.version().StartHeight { - s.askMoreHeaders(p) - } - for { - getaddrMsg := newMessage(s.net, cmdGetAddr, nil) - p.Send(getaddrMsg) - - time.Sleep(30 * time.Second) +// PeerCount returns the number of current connected peers. +func (s *Server) PeerCount() (n int) { + s.peerOp <- func(peers map[Peer]bool) { + n = len(peers) } + <-s.peerOpDone + return } -type peerCount struct { - count chan int -} +func (s *Server) Start() error { + fmt.Println(logo()) + fmt.Println("") + s.printConfiguration() -// peerCount returns the number of connected peers to this server. -func (s *Server) peerCount() int { - ch := peerCount{ - count: make(chan int), + if err := s.createListener(); err != nil { + return err } - s.peerCountCh <- ch - - return <-ch.count + go s.run() + go s.listenTCP() + go s.connectToPeers(s.Seeds...) + select {} } -// StartOpts holds the server configuration. -type StartOpts struct { - // tcp port - TCP int - // slice of peer addresses the server will connect to - Seeds []string - // JSON-RPC port. If 0 no RPC handler will be attached. - RPC int +func (s *Server) Quit() { + s.quit <- struct{}{} +} + +func (s *Server) printState() { + w := tabwriter.NewWriter(os.Stdout, 0, 0, 4, ' ', 0) + fmt.Fprintf(w, "connected peers:\t%d/%d\n", s.PeerCount(), s.MaxPeers) + w.Flush() +} + +func (s *Server) printConfiguration() { + w := tabwriter.NewWriter(os.Stdout, 0, 0, 4, ' ', 0) + fmt.Fprintf(w, "user agent:\t%s\n", s.UserAgent) + fmt.Fprintf(w, "id:\t%d\n", s.id) + fmt.Fprintf(w, "network:\t%s\n", s.Net) + fmt.Fprintf(w, "listen TCP:\t%d\n", s.ListenTCP) + fmt.Fprintf(w, "listen RPC:\t%d\n", s.ListenRPC) + fmt.Fprintf(w, "relay:\t%v\n", s.Relay) + fmt.Fprintf(w, "max peers:\t%d\n", s.MaxPeers) + chainer := s.proto.(Noder) + fmt.Fprintf(w, "current height:\t%d\n", chainer.blockchain().HeaderHeight()) + fmt.Fprintln(w, "") + w.Flush() } func logo() string { diff --git a/pkg/network/server_test.go b/pkg/network/server_test.go index 59a9cd0f3..3c3af5f00 100644 --- a/pkg/network/server_test.go +++ b/pkg/network/server_test.go @@ -1,60 +1,86 @@ package network import ( + "os" "testing" "github.com/CityOfZion/neo-go/pkg/network/payload" + "github.com/CityOfZion/neo-go/pkg/util" + log "github.com/go-kit/kit/log" + "github.com/stretchr/testify/assert" ) -// TODO this should be moved to localPeer test. +func TestRegisterPeer(t *testing.T) { + s := newTestServer() + go s.run() -func TestHandleVersionFailWrongPort(t *testing.T) { - s := NewServer(ModePrivNet) - go s.loop() + assert.NotZero(t, s.id) + assert.Zero(t, s.PeerCount()) - p := NewLocalPeer(s) + lenPeers := 10 + for i := 0; i < lenPeers; i++ { + s.register <- newTestPeer() + } + assert.Equal(t, lenPeers, s.PeerCount()) +} - version := payload.NewVersion(1337, 1, "/NEO:0.0.0/", 0, true) - if err := s.handleVersionCmd(version, p); err == nil { - t.Fatal("expected error got nil") +func TestUnregisterPeer(t *testing.T) { + s := newTestServer() + go s.run() + + peer := newTestPeer() + s.register <- peer + s.register <- newTestPeer() + s.register <- newTestPeer() + assert.Equal(t, 3, s.PeerCount()) + + s.unregister <- peer + assert.Equal(t, 2, s.PeerCount()) +} + +type testNode struct{} + +func (t testNode) version() *payload.Version { + return &payload.Version{} +} + +func (t testNode) handleProto(msg *Message, p Peer) {} + +func newTestServer() *Server { + return &Server{ + logger: log.NewLogfmtLogger(os.Stderr), + id: util.RandUint32(1000000, 9999999), + quit: make(chan struct{}, 1), + register: make(chan Peer), + unregister: make(chan Peer), + badAddrOp: make(chan func(map[string]bool)), + badAddrOpDone: make(chan struct{}), + peerOp: make(chan func(map[Peer]bool)), + peerOpDone: make(chan struct{}), + proto: testNode{}, } } -func TestHandleVersionFailIdenticalNonce(t *testing.T) { - s := NewServer(ModePrivNet) - go s.loop() +type testPeer struct { + done chan struct{} +} - p := NewLocalPeer(s) - - version := payload.NewVersion(s.id, 1, "/NEO:0.0.0/", 0, true) - if err := s.handleVersionCmd(version, p); err == nil { - t.Fatal("expected error got nil") +func newTestPeer() testPeer { + return testPeer{ + done: make(chan struct{}), } } -func TestHandleVersion(t *testing.T) { - s := NewServer(ModePrivNet) - go s.loop() - - p := NewLocalPeer(s) - - version := payload.NewVersion(1337, p.addr().Port, "/NEO:0.0.0/", 0, true) - if err := s.handleVersionCmd(version, p); err != nil { - t.Fatal(err) - } +func (p testPeer) Version() *payload.Version { + return &payload.Version{} } -func TestHandleAddrCmd(t *testing.T) { - // todo +func (p testPeer) Endpoint() util.Endpoint { + return util.Endpoint{} } -func TestHandleGetAddrCmd(t *testing.T) { - // todo -} +func (p testPeer) Send(msg *Message) {} -func TestHandleInv(t *testing.T) { - // todo -} -func TestHandleBlockCmd(t *testing.T) { - // todo +func (p testPeer) Done() chan struct{} { + return p.done } diff --git a/pkg/network/tcp.go b/pkg/network/tcp.go deleted file mode 100644 index 80c836585..000000000 --- a/pkg/network/tcp.go +++ /dev/null @@ -1,254 +0,0 @@ -package network - -import ( - "bytes" - "errors" - "fmt" - "net" - - "github.com/CityOfZion/neo-go/pkg/core" - "github.com/CityOfZion/neo-go/pkg/network/payload" - "github.com/CityOfZion/neo-go/pkg/util" -) - -func listenTCP(s *Server, port int) error { - ln, err := net.Listen("tcp", fmt.Sprintf(":%d", port)) - if err != nil { - return err - } - - s.listener = ln - - for { - conn, err := ln.Accept() - if err != nil { - return err - } - - go handleConnection(s, conn) - } -} - -func connectToRemoteNode(s *Server, address string) { - conn, err := net.Dial("tcp", address) - if err != nil { - s.logger.Printf("failed to connect to remote node %s", address) - if conn != nil { - conn.Close() - } - return - } - go handleConnection(s, conn) -} - -func connectToSeeds(s *Server, addrs []string) { - for _, addr := range addrs { - go connectToRemoteNode(s, addr) - } -} - -func handleConnection(s *Server, conn net.Conn) { - peer := NewTCPPeer(conn, s) - s.register <- peer - - // remove the peer from connected peers and cleanup the connection. - defer func() { - peer.disconnect() - }() - - // Start a goroutine that will handle all outgoing messages. - go peer.writeLoop() - // Start a goroutine that will handle all incomming messages. - go handleMessage(s, peer) - - // Read from the connection and decode it into a Message ready for processing. - for { - msg := &Message{} - if err := msg.decode(conn); err != nil { - s.logger.Printf("decode error: %s", err) - break - } - - peer.receive <- msg - } -} - -// handleMessage multiplexes the message received from a TCP connection to a server command. -func handleMessage(s *Server, p *TCPPeer) { - var err error - - for { - msg := <-p.receive - command := msg.commandType() - - // s.logger.Printf("IN :: %d :: %s :: %v", p.id(), command, msg) - - switch command { - case cmdVersion: - version := msg.Payload.(*payload.Version) - if err = s.handleVersionCmd(version, p); err != nil { - break - } - p.nonce = version.Nonce - p.pVersion = version - - // When a node receives a connection request, it declares its version immediately. - // There will be no other communication until both sides are getting versions of each other. - // When a node receives the version message, it replies to a verack as a response immediately. - // NOTE: The current official NEO nodes dont mimic this behaviour. There is small chance that the - // official nodes will not respond directly with a verack after we sended our version. - // is this a bug? - anthdm 02/02/2018 - msgVerack := <-p.receive - if msgVerack.commandType() != cmdVerack { - err = errors.New("expected verack after sended out version") - break - } - - // start the protocol - go s.startProtocol(p) - case cmdAddr: - addrList := msg.Payload.(*payload.AddressList) - err = s.handleAddrCmd(addrList, p) - case cmdGetAddr: - err = s.handleGetaddrCmd(msg, p) - case cmdInv: - inv := msg.Payload.(*payload.Inventory) - err = s.handleInvCmd(inv, p) - case cmdBlock: - block := msg.Payload.(*core.Block) - err = s.handleBlockCmd(block, p) - case cmdConsensus: - case cmdTX: - case cmdVerack: - // If we receive a verack here we disconnect. We already handled the verack - // when we sended our version. - err = errors.New("verack already received") - case cmdGetHeaders: - case cmdGetBlocks: - case cmdGetData: - case cmdHeaders: - headers := msg.Payload.(*payload.Headers) - err = s.handleHeadersCmd(headers, p) - default: - // This command is unknown by the server. - err = fmt.Errorf("unknown command received %v", msg.Command) - break - } - - // catch all errors here and disconnect. - if err != nil { - s.logger.Printf("processing message failed: %s", err) - break - } - } - - // Disconnect the peer when breaked out of the loop. - p.disconnect() -} - -type sendTuple struct { - msg *Message - err chan error -} - -// TCPPeer represents a remote node, backed by TCP transport. -type TCPPeer struct { - s *Server - // nonce (id) of the peer. - nonce uint32 - // underlying TCP connection - conn net.Conn - // host and port information about this peer. - endpoint util.Endpoint - // channel to coordinate messages writen back to the connection. - send chan sendTuple - // channel to receive from underlying connection. - receive chan *Message - // the version sended out by the peer when connected. - pVersion *payload.Version -} - -// NewTCPPeer returns a pointer to a TCP Peer. -func NewTCPPeer(conn net.Conn, s *Server) *TCPPeer { - e, _ := util.EndpointFromString(conn.RemoteAddr().String()) - - return &TCPPeer{ - conn: conn, - send: make(chan sendTuple), - receive: make(chan *Message), - endpoint: e, - s: s, - } -} - -// Send needed to implement the network.Peer interface -// and provide the functionality to send a message to -// the current peer. -func (p *TCPPeer) Send(msg *Message) error { - t := sendTuple{ - msg: msg, - err: make(chan error), - } - - p.send <- t - - return <-t.err -} - -func (p *TCPPeer) version() *payload.Version { - return p.pVersion -} - -// id implements the peer interface -func (p *TCPPeer) id() uint32 { - return p.nonce -} - -// endpoint implements the peer interface -func (p *TCPPeer) addr() util.Endpoint { - return p.endpoint -} - -// disconnect disconnects the peer, cleaning up all its resources. -// 3 goroutines needs to be cleanup (writeLoop, handleConnection and handleMessage) -func (p *TCPPeer) disconnect() { - select { - case <-p.send: - case <-p.receive: - default: - close(p.send) - close(p.receive) - p.s.unregister <- p - p.conn.Close() - } -} - -// writeLoop writes messages to the underlying TCP connection. -// A goroutine writeLoop is started for each connection. -// There should be at most one writer to a connection executing -// all writes from this goroutine. -func (p *TCPPeer) writeLoop() { - // clean up the connection. - defer func() { - p.disconnect() - }() - - // resuse this buffer - buf := new(bytes.Buffer) - for { - t := <-p.send - if t.msg == nil { - break // send probably closed. - } - - // p.s.logger.Printf("OUT :: %s :: %+v", t.msg.commandType(), t.msg.Payload) - - if err := t.msg.encode(buf); err != nil { - t.err <- err - } - _, err := p.conn.Write(buf.Bytes()) - t.err <- err - - buf.Reset() - } -} diff --git a/pkg/network/tcp_peer.go b/pkg/network/tcp_peer.go new file mode 100644 index 000000000..efce677d9 --- /dev/null +++ b/pkg/network/tcp_peer.go @@ -0,0 +1,134 @@ +package network + +import ( + "bytes" + "net" + "os" + "time" + + "github.com/CityOfZion/neo-go/pkg/network/payload" + "github.com/CityOfZion/neo-go/pkg/util" + log "github.com/go-kit/kit/log" +) + +// TCPPeer represents a connected remote node in the +// network over TCP. +type TCPPeer struct { + // The endpoint of the peer. + endpoint util.Endpoint + + // underlying connection. + conn net.Conn + + // The version the peer declared when connecting. + version *payload.Version + + // connectedAt is the timestamp the peer connected to + // the network. + connectedAt time.Time + + // handleProto is the handler that will handle the + // incoming message along with its peer. + handleProto protoHandleFunc + + // Done is used to broadcast this peer has stopped running + // and should be removed as reference. + done chan struct{} + send chan *Message + + logger log.Logger +} + +// NewTCPPeer creates a new peer from a TCP connection. +func NewTCPPeer(conn net.Conn, fun protoHandleFunc) *TCPPeer { + e := util.NewEndpoint(conn.RemoteAddr().String()) + logger := log.NewLogfmtLogger(os.Stderr) + logger = log.With(logger, "component", "peer", "endpoint", e) + + return &TCPPeer{ + endpoint: e, + conn: conn, + done: make(chan struct{}), + send: make(chan *Message), + logger: logger, + connectedAt: time.Now().UTC(), + handleProto: fun, + } +} + +// Version implements the Peer interface. +func (p *TCPPeer) Version() *payload.Version { + return p.version +} + +// Endpoint implements the Peer interface. +func (p *TCPPeer) Endpoint() util.Endpoint { + return p.endpoint +} + +// Send implements the Peer interface. +func (p *TCPPeer) Send(msg *Message) { + p.send <- msg +} + +// Done implemnets the Peer interface. +func (p *TCPPeer) Done() chan struct{} { + return p.done +} + +func (p *TCPPeer) run() error { + errCh := make(chan error, 1) + + go p.readLoop(errCh) + go p.writeLoop(errCh) + + err := <-errCh + p.logger.Log("err", err) + p.cleanup() + return err +} + +func (p *TCPPeer) readLoop(errCh chan error) { + for { + msg := &Message{} + if err := msg.decode(p.conn); err != nil { + errCh <- err + break + } + p.handleMessage(msg) + } +} + +func (p *TCPPeer) writeLoop(errCh chan error) { + buf := new(bytes.Buffer) + + for { + msg := <-p.send + if err := msg.encode(buf); err != nil { + errCh <- err + break + } + if _, err := p.conn.Write(buf.Bytes()); err != nil { + errCh <- err + break + } + buf.Reset() + } +} + +func (p *TCPPeer) cleanup() { + p.conn.Close() + close(p.send) + p.done <- struct{}{} +} + +func (p *TCPPeer) handleMessage(msg *Message) { + switch msg.CommandType() { + case CMDVersion: + version := msg.Payload.(*payload.Version) + p.version = version + p.handleProto(msg, p) + default: + p.handleProto(msg, p) + } +} diff --git a/pkg/util/endpoint.go b/pkg/util/endpoint.go index a6b5030d5..bfbb9831b 100644 --- a/pkg/util/endpoint.go +++ b/pkg/util/endpoint.go @@ -12,14 +12,11 @@ type Endpoint struct { Port uint16 } -// EndpointFromString returns an Endpoint from the given string. -// For now this only handles the most simple hostport form. -// e.g. 127.0.0.1:3000 -// This should be enough to work with for now. -func EndpointFromString(s string) (Endpoint, error) { +// NewEndpoint creates an Endpoint from the given string. +func NewEndpoint(s string) (e Endpoint) { hostPort := strings.Split(s, ":") if len(hostPort) != 2 { - return Endpoint{}, fmt.Errorf("invalid address string: %s", s) + return e } host := hostPort[0] port := hostPort[1] @@ -36,7 +33,7 @@ func EndpointFromString(s string) (Endpoint, error) { p, _ := strconv.Atoi(port) - return Endpoint{buf, uint16(p)}, nil + return Endpoint{buf, uint16(p)} } // Network implements the net.Addr interface. diff --git a/pkg/util/uint256.go b/pkg/util/uint256.go index d082d97cf..c98bcc658 100644 --- a/pkg/util/uint256.go +++ b/pkg/util/uint256.go @@ -34,7 +34,7 @@ func Uint256DecodeBytes(b []byte) (u Uint256, err error) { return u, nil } -// ToSlice returns a byte slice representation of u. +// Bytes returns a byte slice representation of u. func (u Uint256) Bytes() []byte { b := make([]byte, uint256Size) for i := 0; i < uint256Size; i++ { diff --git a/pkg/wallet/wif.go b/pkg/wallet/wif.go index 97f179256..44f919114 100644 --- a/pkg/wallet/wif.go +++ b/pkg/wallet/wif.go @@ -8,7 +8,7 @@ import ( ) const ( - // The WIF network version used to decode and encode WIF keys. + // WIFVersion is the version used to decode and encode WIF keys. WIFVersion = 0x80 )