diff --git a/pkg/peer/config.go b/pkg/peer/config.go index e4794c282..47c688741 100644 --- a/pkg/peer/config.go +++ b/pkg/peer/config.go @@ -14,34 +14,18 @@ type LocalConfig struct { ProtocolVer protocol.Version Relay bool Port uint16 - // pointer to config will keep the startheight updated for each version - //Message we plan to send - StartHeight func() uint32 + + // pointer to config will keep the startheight updated + StartHeight func() uint32 + + // Response Handlers OnHeader func(*Peer, *payload.HeadersMessage) - OnGetHeaders func(msg *payload.GetHeadersMessage) // returns HeaderMessage + OnGetHeaders func(*Peer, *payload.GetHeadersMessage) OnAddr func(*Peer, *payload.AddrMessage) OnGetAddr func(*Peer, *payload.GetAddrMessage) OnInv func(*Peer, *payload.InvMessage) - OnGetData func(msg *payload.GetDataMessage) + OnGetData func(*Peer, *payload.GetDataMessage) OnBlock func(*Peer, *payload.BlockMessage) - OnGetBlocks func(msg *payload.GetBlocksMessage) + OnGetBlocks func(*Peer, *payload.GetBlocksMessage) + OnTx func(*Peer, *payload.TXMessage) } - -// func DefaultConfig() LocalConfig { -// return LocalConfig{ -// Net: protocol.MainNet, -// UserAgent: "NEO-GO-Default", -// Services: protocol.NodePeerService, -// Nonce: 1200, -// ProtocolVer: 0, -// Relay: false, -// Port: 10332, -// // pointer to config will keep the startheight updated for each version -// //Message we plan to send -// StartHeight: DefaultHeight, -// } -// } - -// func DefaultHeight() uint32 { -// return 10 -// } diff --git a/pkg/peer/peer.go b/pkg/peer/peer.go index 035dad835..94f9c685c 100644 --- a/pkg/peer/peer.go +++ b/pkg/peer/peer.go @@ -58,6 +58,8 @@ type Peer struct { config LocalConfig conn net.Conn + startHeight uint32 + // atomic vals disconnected int32 @@ -84,20 +86,18 @@ type Peer struct { // NewPeer returns a new NEO peer func NewPeer(con net.Conn, inbound bool, cfg LocalConfig) *Peer { - p := Peer{} - p.inch = make(chan func(), inputBufferSize) - p.outch = make(chan func(), outputBufferSize) - p.quitch = make(chan struct{}, 1) - p.inbound = inbound - p.config = cfg - p.conn = con - p.createdAt = time.Now() - p.addr = p.conn.RemoteAddr().String() - - p.Detector = stall.NewDetector(responseTime, tickerInterval) - - // TODO: set the unchangeable states - return &p + return &Peer{ + inch: make(chan func(), inputBufferSize), + outch: make(chan func(), outputBufferSize), + quitch: make(chan struct{}, 1), + inbound: inbound, + config: cfg, + conn: con, + createdAt: time.Now(), + startHeight: 0, + addr: con.RemoteAddr().String(), + Detector: stall.NewDetector(responseTime, tickerInterval), + } } // Write to a peer @@ -125,7 +125,6 @@ func (p *Peer) Disconnect() { p.conn.Close() fmt.Println("Disconnected Peer with address", p.RemoteAddr().String()) - } // Port returns the peers port @@ -138,6 +137,11 @@ func (p *Peer) CreatedAt() time.Time { return p.createdAt } +// Height returns the latest recorded height of this peer +func (p *Peer) Height() uint32 { + return p.startHeight +} + // CanRelay returns true, if the peer can relay information func (p *Peer) CanRelay() bool { return p.relay @@ -163,11 +167,6 @@ func (p *Peer) Inbound() bool { return p.inbound } -// UserAgent returns this nodes, useragent -func (p *Peer) UserAgent() string { - return p.config.UserAgent -} - // IsVerackReceived returns true, if this node has // received a verack from this peer func (p *Peer) IsVerackReceived() bool { @@ -204,7 +203,6 @@ func (p *Peer) Run() error { //go p.PingLoop() // since it is not implemented. It will disconnect all other impls. return nil - } // StartProtocol run as a go-routine, will act as our queue for messages @@ -305,128 +303,17 @@ func (p *Peer) WriteLoop() { } } -// OnGetData is called when a GetData message is received -func (p *Peer) OnGetData(msg *payload.GetDataMessage) { - - p.inch <- func() { - if p.config.OnInv != nil { - p.config.OnGetData(msg) - } - fmt.Println("That was an getdata Message please pass func down through config", msg.Command()) - } -} - -//OnTX is callwed when a TX message is received -func (p *Peer) OnTX(msg *payload.TXMessage) { - - p.inch <- func() { - getdata, err := payload.NewGetDataMessage(payload.InvTypeTx) - if err != nil { - fmt.Println("Eor", err) - } - id, err := msg.Tx.ID() - getdata.AddHash(id) - p.Write(getdata) - } -} - -// OnInv is called when a Inv message is received -func (p *Peer) OnInv(msg *payload.InvMessage) { - - p.inch <- func() { - if p.config.OnInv != nil { - p.config.OnInv(p, msg) - } - fmt.Println("That was an inv Message please pass func down through config", msg.Command()) - } -} - -// OnGetHeaders is called when a GetHeaders message is received -func (p *Peer) OnGetHeaders(msg *payload.GetHeadersMessage) { - p.inch <- func() { - if p.config.OnGetHeaders != nil { - p.config.OnGetHeaders(msg) - } - fmt.Println("That was a getheaders message, please pass func down through config", msg.Command()) - - } -} - -// OnAddr is called when a Addr message is received -func (p *Peer) OnAddr(msg *payload.AddrMessage) { - p.inch <- func() { - if p.config.OnAddr != nil { - p.config.OnAddr(p, msg) - } - fmt.Println("That was a addr message, please pass func down through config", msg.Command()) - - } -} - -// OnGetAddr is called when a GetAddr message is received -func (p *Peer) OnGetAddr(msg *payload.GetAddrMessage) { - p.inch <- func() { - if p.config.OnGetAddr != nil { - p.config.OnGetAddr(p, msg) - } - fmt.Println("That was a getaddr message, please pass func down through config", msg.Command()) - - } -} - -// OnGetBlocks is called when a GetBlocks message is received -func (p *Peer) OnGetBlocks(msg *payload.GetBlocksMessage) { - p.inch <- func() { - if p.config.OnGetBlocks != nil { - p.config.OnGetBlocks(msg) - } - fmt.Println("That was a getblocks message, please pass func down through config", msg.Command()) - } -} - -// OnBlocks is called when a Blocks message is received -func (p *Peer) OnBlocks(msg *payload.BlockMessage) { - p.inch <- func() { - if p.config.OnBlock != nil { - p.config.OnBlock(p, msg) - } - } -} - -// OnVersion Listener will be called -// during the handshake, any error checking should be done here for the versionMessage. -// This should only ever be called during the handshake. Any other place and the peer will disconnect. -func (p *Peer) OnVersion(msg *payload.VersionMessage) error { - if msg.Nonce == p.config.Nonce { - p.conn.Close() - return errors.New("Self connection, disconnecting Peer") - } - p.versionKnown = true - p.port = msg.Port - p.services = msg.Services - p.userAgent = string(msg.UserAgent) - p.createdAt = time.Now() - p.relay = msg.Relay - return nil -} - -// OnHeaders is called when a Headers message is received -func (p *Peer) OnHeaders(msg *payload.HeadersMessage) { - fmt.Println("We have received the headers") - p.inch <- func() { - if p.config.OnHeader != nil { - p.config.OnHeader(p, msg) - } - } -} +// Outgoing Requests // RequestHeaders will write a getheaders to this peer func (p *Peer) RequestHeaders(hash util.Uint256) error { c := make(chan error, 0) p.outch <- func() { - p.Detector.AddMessage(command.GetHeaders) getHeaders, err := payload.NewGetHeadersMessage([]util.Uint256{hash}, util.Uint256{}) err = p.Write(getHeaders) + if err != nil { + p.Detector.AddMessage(command.GetHeaders) + } c <- err } return <-c @@ -437,17 +324,19 @@ func (p *Peer) RequestBlocks(hashes []util.Uint256) error { c := make(chan error, 0) p.outch <- func() { - p.Detector.AddMessage(command.GetData) getdata, err := payload.NewGetDataMessage(payload.InvTypeBlock) err = getdata.AddHashes(hashes) if err != nil { c <- err return } + err = p.Write(getdata) + if err != nil { + p.Detector.AddMessage(command.GetData) + } + c <- err } - return <-c - } diff --git a/pkg/peer/peer_test.go b/pkg/peer/peer_test.go index adcb4d342..fd27aeb55 100644 --- a/pkg/peer/peer_test.go +++ b/pkg/peer/peer_test.go @@ -1,7 +1,6 @@ package peer_test import ( - "fmt" "net" "testing" "time" @@ -21,11 +20,11 @@ func returnConfig() peer.LocalConfig { OnAddr := func(p *peer.Peer, msg *payload.AddrMessage) {} OnHeader := func(p *peer.Peer, msg *payload.HeadersMessage) {} - OnGetHeaders := func(msg *payload.GetHeadersMessage) {} + OnGetHeaders := func(p *peer.Peer, msg *payload.GetHeadersMessage) {} OnInv := func(p *peer.Peer, msg *payload.InvMessage) {} - OnGetData := func(msg *payload.GetDataMessage) {} + OnGetData := func(p *peer.Peer, msg *payload.GetDataMessage) {} OnBlock := func(p *peer.Peer, msg *payload.BlockMessage) {} - OnGetBlocks := func(msg *payload.GetBlocksMessage) {} + OnGetBlocks := func(p *peer.Peer, msg *payload.GetBlocksMessage) {} return peer.LocalConfig{ Net: protocol.MainNet, @@ -157,17 +156,9 @@ func TestConfigurations(t *testing.T) { assert.Equal(t, config.Services, p.Services()) - assert.Equal(t, config.UserAgent, p.UserAgent()) - assert.Equal(t, config.Relay, p.CanRelay()) assert.WithinDuration(t, time.Now(), p.CreatedAt(), 1*time.Second) - -} - -func TestHandshakeCancelled(t *testing.T) { - // These are the conditions which should invalidate the handshake. - // Make sure peer is disconnected. } func TestPeerDisconnect(t *testing.T) { @@ -178,21 +169,17 @@ func TestPeerDisconnect(t *testing.T) { inbound := true config := returnConfig() p := peer.NewPeer(conn, inbound, config) - fmt.Println("Calling disconnect") + p.Disconnect() - fmt.Println("Disconnect finished calling") - verack, _ := payload.NewVerackMessage() + verack, err := payload.NewVerackMessage() + assert.Nil(t, err) - fmt.Println(" We good here") + err = p.Write(verack) + assert.NotNil(t, err) - err := p.Write(verack) - - assert.NotEqual(t, err, nil) - - // Check if Stall detector is still running + // Check if stall detector is still running _, ok := <-p.Detector.Quitch assert.Equal(t, ok, false) - } func TestNotifyDisconnect(t *testing.T) { diff --git a/pkg/peer/responsehandlers.go b/pkg/peer/responsehandlers.go new file mode 100644 index 000000000..303ee0759 --- /dev/null +++ b/pkg/peer/responsehandlers.go @@ -0,0 +1,111 @@ +package peer + +import ( + "errors" + "time" + + "github.com/CityOfZion/neo-go/pkg/wire/payload" +) + +// OnGetData is called when a GetData message is received +func (p *Peer) OnGetData(msg *payload.GetDataMessage) { + p.inch <- func() { + if p.config.OnInv != nil { + p.config.OnGetData(p, msg) + } + } +} + +//OnTX is called when a TX message is received +func (p *Peer) OnTX(msg *payload.TXMessage) { + p.inch <- func() { + p.inch <- func() { + if p.config.OnTx != nil { + p.config.OnTx(p, msg) + } + } + } +} + +// OnInv is called when a Inv message is received +func (p *Peer) OnInv(msg *payload.InvMessage) { + p.inch <- func() { + if p.config.OnInv != nil { + p.config.OnInv(p, msg) + } + } +} + +// OnGetHeaders is called when a GetHeaders message is received +func (p *Peer) OnGetHeaders(msg *payload.GetHeadersMessage) { + p.inch <- func() { + if p.config.OnGetHeaders != nil { + p.config.OnGetHeaders(p, msg) + } + } +} + +// OnAddr is called when a Addr message is received +func (p *Peer) OnAddr(msg *payload.AddrMessage) { + p.inch <- func() { + if p.config.OnAddr != nil { + p.config.OnAddr(p, msg) + } + } +} + +// OnGetAddr is called when a GetAddr message is received +func (p *Peer) OnGetAddr(msg *payload.GetAddrMessage) { + p.inch <- func() { + if p.config.OnGetAddr != nil { + p.config.OnGetAddr(p, msg) + } + } +} + +// OnGetBlocks is called when a GetBlocks message is received +func (p *Peer) OnGetBlocks(msg *payload.GetBlocksMessage) { + p.inch <- func() { + if p.config.OnGetBlocks != nil { + p.config.OnGetBlocks(p, msg) + } + } +} + +// OnBlocks is called when a Blocks message is received +func (p *Peer) OnBlocks(msg *payload.BlockMessage) { + p.Detector.RemoveMessage(msg.Command()) + p.inch <- func() { + if p.config.OnBlock != nil { + p.config.OnBlock(p, msg) + } + } +} + +// OnHeaders is called when a Headers message is received +func (p *Peer) OnHeaders(msg *payload.HeadersMessage) { + p.Detector.RemoveMessage(msg.Command()) + p.inch <- func() { + if p.config.OnHeader != nil { + p.config.OnHeader(p, msg) + } + } +} + +// OnVersion Listener will be called +// during the handshake, any error checking should be done here for the versionMessage. +// This should only ever be called during the handshake. Any other place and the peer will disconnect. +func (p *Peer) OnVersion(msg *payload.VersionMessage) error { + if msg.Nonce == p.config.Nonce { + p.conn.Close() + return errors.New("self connection, disconnecting Peer") + } + p.versionKnown = true + p.port = msg.Port + p.services = msg.Services + p.userAgent = string(msg.UserAgent) + p.createdAt = time.Now() + p.relay = msg.Relay + p.startHeight = msg.StartHeight + return nil +} diff --git a/pkg/peer/stall/stall.go b/pkg/peer/stall/stall.go index fc19891fa..e69bcb442 100644 --- a/pkg/peer/stall/stall.go +++ b/pkg/peer/stall/stall.go @@ -61,6 +61,7 @@ func (d *Detector) loop() { d.lock.RUnlock() for _, deadline := range resp { if now.After(deadline) { + fmt.Println(resp) fmt.Println("Deadline passed") return } @@ -99,7 +100,7 @@ func (d *Detector) AddMessage(cmd command.Type) { // peer. This will remove the pendingresponse message from the map. // The command passed through is the command we received func (d *Detector) RemoveMessage(cmd command.Type) { - cmds := d.addMessage(cmd) + cmds := d.removeMessage(cmd) d.lock.Lock() for _, cmd := range cmds { delete(d.responses, cmd) @@ -137,10 +138,8 @@ func (d *Detector) addMessage(cmd command.Type) []command.Type { case command.GetAddr: // We now will expect a Headers Message cmds = append(cmds, command.Addr) - case command.GetData: // We will now expect a block/tx message - // We can optimise this by including the exact inventory type, however it is not needed cmds = append(cmds, command.Block) cmds = append(cmds, command.TX) case command.GetBlocks: @@ -159,19 +158,18 @@ func (d *Detector) removeMessage(cmd command.Type) []command.Type { switch cmd { case command.Block: - // We will now expect a block/tx message + // We will now remove a block and tx message cmds = append(cmds, command.Block) cmds = append(cmds, command.TX) case command.TX: - // We will now expect a block/tx message + // We will now remove a block and tx message cmds = append(cmds, command.Block) cmds = append(cmds, command.TX) - case command.GetBlocks: - // we will now expect a inv message - cmds = append(cmds, command.Inv) - default: + case command.Verack: // We will now expect a verack cmds = append(cmds, cmd) + default: + cmds = append(cmds, cmd) } return cmds } diff --git a/pkg/peer/stall/stall_test.go b/pkg/peer/stall/stall_test.go index 4d5494e12..83de4e3a0 100644 --- a/pkg/peer/stall/stall_test.go +++ b/pkg/peer/stall/stall_test.go @@ -22,7 +22,7 @@ func TestAddRemoveMessage(t *testing.T) { assert.Equal(t, 1, len(mp)) assert.IsType(t, time.Time{}, mp[command.GetAddr]) - d.RemoveMessage(command.GetAddr) + d.RemoveMessage(command.Addr) mp = d.GetMessages() assert.Equal(t, 0, len(mp))