[Peer] Refactor (#240)

[Peer]

- Closes #239

- moved response handlers to their own functions

- removed DefaultConfig from LocalConfig file

- passed peer as a parameter to all response handlers

- added peer start height

- refactored NewPeer function to be more concise and clear

- removed empty lines at end of functions

- Added AddMessage/RemoveMessage for Detector in outgoing and ingoing
requests for Block and Headers
This commit is contained in:
decentralisedkev 2019-03-28 19:09:55 +00:00 committed by GitHub
parent ce1fe72607
commit dc5de1fa6d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 165 additions and 196 deletions

View file

@ -14,34 +14,18 @@ type LocalConfig struct {
ProtocolVer protocol.Version ProtocolVer protocol.Version
Relay bool Relay bool
Port uint16 Port uint16
// pointer to config will keep the startheight updated for each version
//Message we plan to send // pointer to config will keep the startheight updated
StartHeight func() uint32 StartHeight func() uint32
// Response Handlers
OnHeader func(*Peer, *payload.HeadersMessage) OnHeader func(*Peer, *payload.HeadersMessage)
OnGetHeaders func(msg *payload.GetHeadersMessage) // returns HeaderMessage OnGetHeaders func(*Peer, *payload.GetHeadersMessage)
OnAddr func(*Peer, *payload.AddrMessage) OnAddr func(*Peer, *payload.AddrMessage)
OnGetAddr func(*Peer, *payload.GetAddrMessage) OnGetAddr func(*Peer, *payload.GetAddrMessage)
OnInv func(*Peer, *payload.InvMessage) OnInv func(*Peer, *payload.InvMessage)
OnGetData func(msg *payload.GetDataMessage) OnGetData func(*Peer, *payload.GetDataMessage)
OnBlock func(*Peer, *payload.BlockMessage) OnBlock func(*Peer, *payload.BlockMessage)
OnGetBlocks func(msg *payload.GetBlocksMessage) OnGetBlocks func(*Peer, *payload.GetBlocksMessage)
OnTx func(*Peer, *payload.TXMessage)
} }
// func DefaultConfig() LocalConfig {
// return LocalConfig{
// Net: protocol.MainNet,
// UserAgent: "NEO-GO-Default",
// Services: protocol.NodePeerService,
// Nonce: 1200,
// ProtocolVer: 0,
// Relay: false,
// Port: 10332,
// // pointer to config will keep the startheight updated for each version
// //Message we plan to send
// StartHeight: DefaultHeight,
// }
// }
// func DefaultHeight() uint32 {
// return 10
// }

View file

@ -58,6 +58,8 @@ type Peer struct {
config LocalConfig config LocalConfig
conn net.Conn conn net.Conn
startHeight uint32
// atomic vals // atomic vals
disconnected int32 disconnected int32
@ -84,20 +86,18 @@ type Peer struct {
// NewPeer returns a new NEO peer // NewPeer returns a new NEO peer
func NewPeer(con net.Conn, inbound bool, cfg LocalConfig) *Peer { func NewPeer(con net.Conn, inbound bool, cfg LocalConfig) *Peer {
p := Peer{} return &Peer{
p.inch = make(chan func(), inputBufferSize) inch: make(chan func(), inputBufferSize),
p.outch = make(chan func(), outputBufferSize) outch: make(chan func(), outputBufferSize),
p.quitch = make(chan struct{}, 1) quitch: make(chan struct{}, 1),
p.inbound = inbound inbound: inbound,
p.config = cfg config: cfg,
p.conn = con conn: con,
p.createdAt = time.Now() createdAt: time.Now(),
p.addr = p.conn.RemoteAddr().String() startHeight: 0,
addr: con.RemoteAddr().String(),
p.Detector = stall.NewDetector(responseTime, tickerInterval) Detector: stall.NewDetector(responseTime, tickerInterval),
}
// TODO: set the unchangeable states
return &p
} }
// Write to a peer // Write to a peer
@ -125,7 +125,6 @@ func (p *Peer) Disconnect() {
p.conn.Close() p.conn.Close()
fmt.Println("Disconnected Peer with address", p.RemoteAddr().String()) fmt.Println("Disconnected Peer with address", p.RemoteAddr().String())
} }
// Port returns the peers port // Port returns the peers port
@ -138,6 +137,11 @@ func (p *Peer) CreatedAt() time.Time {
return p.createdAt return p.createdAt
} }
// Height returns the latest recorded height of this peer
func (p *Peer) Height() uint32 {
return p.startHeight
}
// CanRelay returns true, if the peer can relay information // CanRelay returns true, if the peer can relay information
func (p *Peer) CanRelay() bool { func (p *Peer) CanRelay() bool {
return p.relay return p.relay
@ -163,11 +167,6 @@ func (p *Peer) Inbound() bool {
return p.inbound return p.inbound
} }
// UserAgent returns this nodes, useragent
func (p *Peer) UserAgent() string {
return p.config.UserAgent
}
// IsVerackReceived returns true, if this node has // IsVerackReceived returns true, if this node has
// received a verack from this peer // received a verack from this peer
func (p *Peer) IsVerackReceived() bool { func (p *Peer) IsVerackReceived() bool {
@ -204,7 +203,6 @@ func (p *Peer) Run() error {
//go p.PingLoop() // since it is not implemented. It will disconnect all other impls. //go p.PingLoop() // since it is not implemented. It will disconnect all other impls.
return nil return nil
} }
// StartProtocol run as a go-routine, will act as our queue for messages // StartProtocol run as a go-routine, will act as our queue for messages
@ -305,128 +303,17 @@ func (p *Peer) WriteLoop() {
} }
} }
// OnGetData is called when a GetData message is received // Outgoing Requests
func (p *Peer) OnGetData(msg *payload.GetDataMessage) {
p.inch <- func() {
if p.config.OnInv != nil {
p.config.OnGetData(msg)
}
fmt.Println("That was an getdata Message please pass func down through config", msg.Command())
}
}
//OnTX is callwed when a TX message is received
func (p *Peer) OnTX(msg *payload.TXMessage) {
p.inch <- func() {
getdata, err := payload.NewGetDataMessage(payload.InvTypeTx)
if err != nil {
fmt.Println("Eor", err)
}
id, err := msg.Tx.ID()
getdata.AddHash(id)
p.Write(getdata)
}
}
// OnInv is called when a Inv message is received
func (p *Peer) OnInv(msg *payload.InvMessage) {
p.inch <- func() {
if p.config.OnInv != nil {
p.config.OnInv(p, msg)
}
fmt.Println("That was an inv Message please pass func down through config", msg.Command())
}
}
// OnGetHeaders is called when a GetHeaders message is received
func (p *Peer) OnGetHeaders(msg *payload.GetHeadersMessage) {
p.inch <- func() {
if p.config.OnGetHeaders != nil {
p.config.OnGetHeaders(msg)
}
fmt.Println("That was a getheaders message, please pass func down through config", msg.Command())
}
}
// OnAddr is called when a Addr message is received
func (p *Peer) OnAddr(msg *payload.AddrMessage) {
p.inch <- func() {
if p.config.OnAddr != nil {
p.config.OnAddr(p, msg)
}
fmt.Println("That was a addr message, please pass func down through config", msg.Command())
}
}
// OnGetAddr is called when a GetAddr message is received
func (p *Peer) OnGetAddr(msg *payload.GetAddrMessage) {
p.inch <- func() {
if p.config.OnGetAddr != nil {
p.config.OnGetAddr(p, msg)
}
fmt.Println("That was a getaddr message, please pass func down through config", msg.Command())
}
}
// OnGetBlocks is called when a GetBlocks message is received
func (p *Peer) OnGetBlocks(msg *payload.GetBlocksMessage) {
p.inch <- func() {
if p.config.OnGetBlocks != nil {
p.config.OnGetBlocks(msg)
}
fmt.Println("That was a getblocks message, please pass func down through config", msg.Command())
}
}
// OnBlocks is called when a Blocks message is received
func (p *Peer) OnBlocks(msg *payload.BlockMessage) {
p.inch <- func() {
if p.config.OnBlock != nil {
p.config.OnBlock(p, msg)
}
}
}
// OnVersion Listener will be called
// during the handshake, any error checking should be done here for the versionMessage.
// This should only ever be called during the handshake. Any other place and the peer will disconnect.
func (p *Peer) OnVersion(msg *payload.VersionMessage) error {
if msg.Nonce == p.config.Nonce {
p.conn.Close()
return errors.New("Self connection, disconnecting Peer")
}
p.versionKnown = true
p.port = msg.Port
p.services = msg.Services
p.userAgent = string(msg.UserAgent)
p.createdAt = time.Now()
p.relay = msg.Relay
return nil
}
// OnHeaders is called when a Headers message is received
func (p *Peer) OnHeaders(msg *payload.HeadersMessage) {
fmt.Println("We have received the headers")
p.inch <- func() {
if p.config.OnHeader != nil {
p.config.OnHeader(p, msg)
}
}
}
// RequestHeaders will write a getheaders to this peer // RequestHeaders will write a getheaders to this peer
func (p *Peer) RequestHeaders(hash util.Uint256) error { func (p *Peer) RequestHeaders(hash util.Uint256) error {
c := make(chan error, 0) c := make(chan error, 0)
p.outch <- func() { p.outch <- func() {
p.Detector.AddMessage(command.GetHeaders)
getHeaders, err := payload.NewGetHeadersMessage([]util.Uint256{hash}, util.Uint256{}) getHeaders, err := payload.NewGetHeadersMessage([]util.Uint256{hash}, util.Uint256{})
err = p.Write(getHeaders) err = p.Write(getHeaders)
if err != nil {
p.Detector.AddMessage(command.GetHeaders)
}
c <- err c <- err
} }
return <-c return <-c
@ -437,17 +324,19 @@ func (p *Peer) RequestBlocks(hashes []util.Uint256) error {
c := make(chan error, 0) c := make(chan error, 0)
p.outch <- func() { p.outch <- func() {
p.Detector.AddMessage(command.GetData)
getdata, err := payload.NewGetDataMessage(payload.InvTypeBlock) getdata, err := payload.NewGetDataMessage(payload.InvTypeBlock)
err = getdata.AddHashes(hashes) err = getdata.AddHashes(hashes)
if err != nil { if err != nil {
c <- err c <- err
return return
} }
err = p.Write(getdata) err = p.Write(getdata)
if err != nil {
p.Detector.AddMessage(command.GetData)
}
c <- err c <- err
} }
return <-c return <-c
} }

View file

@ -1,7 +1,6 @@
package peer_test package peer_test
import ( import (
"fmt"
"net" "net"
"testing" "testing"
"time" "time"
@ -21,11 +20,11 @@ func returnConfig() peer.LocalConfig {
OnAddr := func(p *peer.Peer, msg *payload.AddrMessage) {} OnAddr := func(p *peer.Peer, msg *payload.AddrMessage) {}
OnHeader := func(p *peer.Peer, msg *payload.HeadersMessage) {} OnHeader := func(p *peer.Peer, msg *payload.HeadersMessage) {}
OnGetHeaders := func(msg *payload.GetHeadersMessage) {} OnGetHeaders := func(p *peer.Peer, msg *payload.GetHeadersMessage) {}
OnInv := func(p *peer.Peer, msg *payload.InvMessage) {} OnInv := func(p *peer.Peer, msg *payload.InvMessage) {}
OnGetData := func(msg *payload.GetDataMessage) {} OnGetData := func(p *peer.Peer, msg *payload.GetDataMessage) {}
OnBlock := func(p *peer.Peer, msg *payload.BlockMessage) {} OnBlock := func(p *peer.Peer, msg *payload.BlockMessage) {}
OnGetBlocks := func(msg *payload.GetBlocksMessage) {} OnGetBlocks := func(p *peer.Peer, msg *payload.GetBlocksMessage) {}
return peer.LocalConfig{ return peer.LocalConfig{
Net: protocol.MainNet, Net: protocol.MainNet,
@ -157,17 +156,9 @@ func TestConfigurations(t *testing.T) {
assert.Equal(t, config.Services, p.Services()) assert.Equal(t, config.Services, p.Services())
assert.Equal(t, config.UserAgent, p.UserAgent())
assert.Equal(t, config.Relay, p.CanRelay()) assert.Equal(t, config.Relay, p.CanRelay())
assert.WithinDuration(t, time.Now(), p.CreatedAt(), 1*time.Second) assert.WithinDuration(t, time.Now(), p.CreatedAt(), 1*time.Second)
}
func TestHandshakeCancelled(t *testing.T) {
// These are the conditions which should invalidate the handshake.
// Make sure peer is disconnected.
} }
func TestPeerDisconnect(t *testing.T) { func TestPeerDisconnect(t *testing.T) {
@ -178,21 +169,17 @@ func TestPeerDisconnect(t *testing.T) {
inbound := true inbound := true
config := returnConfig() config := returnConfig()
p := peer.NewPeer(conn, inbound, config) p := peer.NewPeer(conn, inbound, config)
fmt.Println("Calling disconnect")
p.Disconnect() p.Disconnect()
fmt.Println("Disconnect finished calling") verack, err := payload.NewVerackMessage()
verack, _ := payload.NewVerackMessage() assert.Nil(t, err)
fmt.Println(" We good here") err = p.Write(verack)
assert.NotNil(t, err)
err := p.Write(verack) // Check if stall detector is still running
assert.NotEqual(t, err, nil)
// Check if Stall detector is still running
_, ok := <-p.Detector.Quitch _, ok := <-p.Detector.Quitch
assert.Equal(t, ok, false) assert.Equal(t, ok, false)
} }
func TestNotifyDisconnect(t *testing.T) { func TestNotifyDisconnect(t *testing.T) {

View file

@ -0,0 +1,111 @@
package peer
import (
"errors"
"time"
"github.com/CityOfZion/neo-go/pkg/wire/payload"
)
// OnGetData is called when a GetData message is received
func (p *Peer) OnGetData(msg *payload.GetDataMessage) {
p.inch <- func() {
if p.config.OnInv != nil {
p.config.OnGetData(p, msg)
}
}
}
//OnTX is called when a TX message is received
func (p *Peer) OnTX(msg *payload.TXMessage) {
p.inch <- func() {
p.inch <- func() {
if p.config.OnTx != nil {
p.config.OnTx(p, msg)
}
}
}
}
// OnInv is called when a Inv message is received
func (p *Peer) OnInv(msg *payload.InvMessage) {
p.inch <- func() {
if p.config.OnInv != nil {
p.config.OnInv(p, msg)
}
}
}
// OnGetHeaders is called when a GetHeaders message is received
func (p *Peer) OnGetHeaders(msg *payload.GetHeadersMessage) {
p.inch <- func() {
if p.config.OnGetHeaders != nil {
p.config.OnGetHeaders(p, msg)
}
}
}
// OnAddr is called when a Addr message is received
func (p *Peer) OnAddr(msg *payload.AddrMessage) {
p.inch <- func() {
if p.config.OnAddr != nil {
p.config.OnAddr(p, msg)
}
}
}
// OnGetAddr is called when a GetAddr message is received
func (p *Peer) OnGetAddr(msg *payload.GetAddrMessage) {
p.inch <- func() {
if p.config.OnGetAddr != nil {
p.config.OnGetAddr(p, msg)
}
}
}
// OnGetBlocks is called when a GetBlocks message is received
func (p *Peer) OnGetBlocks(msg *payload.GetBlocksMessage) {
p.inch <- func() {
if p.config.OnGetBlocks != nil {
p.config.OnGetBlocks(p, msg)
}
}
}
// OnBlocks is called when a Blocks message is received
func (p *Peer) OnBlocks(msg *payload.BlockMessage) {
p.Detector.RemoveMessage(msg.Command())
p.inch <- func() {
if p.config.OnBlock != nil {
p.config.OnBlock(p, msg)
}
}
}
// OnHeaders is called when a Headers message is received
func (p *Peer) OnHeaders(msg *payload.HeadersMessage) {
p.Detector.RemoveMessage(msg.Command())
p.inch <- func() {
if p.config.OnHeader != nil {
p.config.OnHeader(p, msg)
}
}
}
// OnVersion Listener will be called
// during the handshake, any error checking should be done here for the versionMessage.
// This should only ever be called during the handshake. Any other place and the peer will disconnect.
func (p *Peer) OnVersion(msg *payload.VersionMessage) error {
if msg.Nonce == p.config.Nonce {
p.conn.Close()
return errors.New("self connection, disconnecting Peer")
}
p.versionKnown = true
p.port = msg.Port
p.services = msg.Services
p.userAgent = string(msg.UserAgent)
p.createdAt = time.Now()
p.relay = msg.Relay
p.startHeight = msg.StartHeight
return nil
}

View file

@ -61,6 +61,7 @@ func (d *Detector) loop() {
d.lock.RUnlock() d.lock.RUnlock()
for _, deadline := range resp { for _, deadline := range resp {
if now.After(deadline) { if now.After(deadline) {
fmt.Println(resp)
fmt.Println("Deadline passed") fmt.Println("Deadline passed")
return return
} }
@ -99,7 +100,7 @@ func (d *Detector) AddMessage(cmd command.Type) {
// peer. This will remove the pendingresponse message from the map. // peer. This will remove the pendingresponse message from the map.
// The command passed through is the command we received // The command passed through is the command we received
func (d *Detector) RemoveMessage(cmd command.Type) { func (d *Detector) RemoveMessage(cmd command.Type) {
cmds := d.addMessage(cmd) cmds := d.removeMessage(cmd)
d.lock.Lock() d.lock.Lock()
for _, cmd := range cmds { for _, cmd := range cmds {
delete(d.responses, cmd) delete(d.responses, cmd)
@ -137,10 +138,8 @@ func (d *Detector) addMessage(cmd command.Type) []command.Type {
case command.GetAddr: case command.GetAddr:
// We now will expect a Headers Message // We now will expect a Headers Message
cmds = append(cmds, command.Addr) cmds = append(cmds, command.Addr)
case command.GetData: case command.GetData:
// We will now expect a block/tx message // We will now expect a block/tx message
// We can optimise this by including the exact inventory type, however it is not needed
cmds = append(cmds, command.Block) cmds = append(cmds, command.Block)
cmds = append(cmds, command.TX) cmds = append(cmds, command.TX)
case command.GetBlocks: case command.GetBlocks:
@ -159,19 +158,18 @@ func (d *Detector) removeMessage(cmd command.Type) []command.Type {
switch cmd { switch cmd {
case command.Block: case command.Block:
// We will now expect a block/tx message // We will now remove a block and tx message
cmds = append(cmds, command.Block) cmds = append(cmds, command.Block)
cmds = append(cmds, command.TX) cmds = append(cmds, command.TX)
case command.TX: case command.TX:
// We will now expect a block/tx message // We will now remove a block and tx message
cmds = append(cmds, command.Block) cmds = append(cmds, command.Block)
cmds = append(cmds, command.TX) cmds = append(cmds, command.TX)
case command.GetBlocks: case command.Verack:
// we will now expect a inv message
cmds = append(cmds, command.Inv)
default:
// We will now expect a verack // We will now expect a verack
cmds = append(cmds, cmd) cmds = append(cmds, cmd)
default:
cmds = append(cmds, cmd)
} }
return cmds return cmds
} }

View file

@ -22,7 +22,7 @@ func TestAddRemoveMessage(t *testing.T) {
assert.Equal(t, 1, len(mp)) assert.Equal(t, 1, len(mp))
assert.IsType(t, time.Time{}, mp[command.GetAddr]) assert.IsType(t, time.Time{}, mp[command.GetAddr])
d.RemoveMessage(command.GetAddr) d.RemoveMessage(command.Addr)
mp = d.GetMessages() mp = d.GetMessages()
assert.Equal(t, 0, len(mp)) assert.Equal(t, 0, len(mp))