diff --git a/pkg/core/blockchain.go b/pkg/core/blockchain.go index f37e8441e..303701ae5 100644 --- a/pkg/core/blockchain.go +++ b/pkg/core/blockchain.go @@ -357,8 +357,7 @@ func (bc *Blockchain) persistBlock(block *Block) error { Email: t.Email, Description: t.Description, } - - fmt.Printf("%+v", contract) + _ = contract case *transaction.InvocationTX: } @@ -430,6 +429,15 @@ func (bc *Blockchain) persist(ctx context.Context) (err error) { "blockHeight": bc.BlockHeight(), "took": time.Since(start), }).Info("blockchain persist completed") + } else { + // So we have some blocks in cache but can't persist them? + // Either there are some stale blocks there or the other way + // around (which was seen in practice) --- there are some fresh + // blocks that we can't persist yet. Some of the latter can be useful + // or can be bogus (higher than the header height we expect at + // the moment). So try to reap oldies and strange newbies, if + // there are any. + bc.blockCache.ReapStrangeBlocks(bc.BlockHeight(), bc.HeaderHeight()) } return diff --git a/pkg/core/cache.go b/pkg/core/cache.go index b2bdec704..c2141851e 100644 --- a/pkg/core/cache.go +++ b/pkg/core/cache.go @@ -71,3 +71,17 @@ func (c *Cache) Delete(h util.Uint256) { defer c.lock.Unlock() delete(c.m, h) } + +// ReapStrangeBlocks drops blocks from cache that don't fit into the +// blkHeight-headHeight interval. Cache should only contain blocks that we +// expect to get and store. +func (c *Cache) ReapStrangeBlocks(blkHeight, headHeight uint32) { + c.lock.Lock() + defer c.lock.Unlock() + for i, b := range c.m { + block, ok := b.(*Block) + if ok && (block.Index < blkHeight || block.Index > headHeight) { + delete(c.m, i) + } + } +} diff --git a/pkg/network/discovery.go b/pkg/network/discovery.go index f8edc7dd8..cbc6464a5 100644 --- a/pkg/network/discovery.go +++ b/pkg/network/discovery.go @@ -6,6 +6,7 @@ import ( const ( maxPoolSize = 200 + connRetries = 3 ) // Discoverer is an interface that is responsible for maintaining @@ -15,22 +16,28 @@ type Discoverer interface { PoolCount() int RequestRemote(int) RegisterBadAddr(string) + RegisterGoodAddr(string) + UnregisterConnectedAddr(string) UnconnectedPeers() []string BadPeers() []string + GoodPeers() []string } // DefaultDiscovery default implementation of the Discoverer interface. type DefaultDiscovery struct { transport Transporter dialTimeout time.Duration - addrs map[string]bool badAddrs map[string]bool - unconnectedAddrs map[string]bool + connectedAddrs map[string]bool + goodAddrs map[string]bool + unconnectedAddrs map[string]int requestCh chan int connectedCh chan string backFill chan string badAddrCh chan string pool chan string + goodCh chan string + unconnectedCh chan string } // NewDefaultDiscovery returns a new DefaultDiscovery. @@ -38,11 +45,14 @@ func NewDefaultDiscovery(dt time.Duration, ts Transporter) *DefaultDiscovery { d := &DefaultDiscovery{ transport: ts, dialTimeout: dt, - addrs: make(map[string]bool), badAddrs: make(map[string]bool), - unconnectedAddrs: make(map[string]bool), + connectedAddrs: make(map[string]bool), + goodAddrs: make(map[string]bool), + unconnectedAddrs: make(map[string]int), requestCh: make(chan int), connectedCh: make(chan string), + goodCh: make(chan string), + unconnectedCh: make(chan string), backFill: make(chan string), badAddrCh: make(chan string), pool: make(chan string, maxPoolSize), @@ -54,9 +64,6 @@ func NewDefaultDiscovery(dt time.Duration, ts Transporter) *DefaultDiscovery { // BackFill implements the Discoverer interface and will backfill the // the pool with the given addresses. func (d *DefaultDiscovery) BackFill(addrs ...string) { - if len(d.pool) == maxPoolSize { - return - } for _, addr := range addrs { d.backFill <- addr } @@ -67,6 +74,17 @@ func (d *DefaultDiscovery) PoolCount() int { return len(d.pool) } +// pushToPoolOrDrop tries to push address given into the pool, but if the pool +// is already full, it just drops it +func (d *DefaultDiscovery) pushToPoolOrDrop(addr string) { + select { + case d.pool <- addr: + // ok, queued + default: + // whatever + } +} + // RequestRemote will try to establish a connection with n nodes. func (d *DefaultDiscovery) RequestRemote(n int) { d.requestCh <- n @@ -96,57 +114,87 @@ func (d *DefaultDiscovery) BadPeers() []string { return addrs } -func (d *DefaultDiscovery) work(addrCh chan string) { - for { - addr := <-addrCh - if err := d.transport.Dial(addr, d.dialTimeout); err != nil { - d.badAddrCh <- addr - } else { - d.connectedCh <- addr - } +// GoodPeers returns all addresses of known good peers (that at least once +// succeded handshaking with us). +func (d *DefaultDiscovery) GoodPeers() []string { + addrs := make([]string, 0, len(d.goodAddrs)) + for addr := range d.goodAddrs { + addrs = append(addrs, addr) + } + return addrs +} + +// RegisterGoodAddr registers good known connected address that passed +// handshake successfuly. +func (d *DefaultDiscovery) RegisterGoodAddr(s string) { + d.goodCh <- s +} + +// UnregisterConnectedAddr tells discoverer that this address is no longer +// connected, but it still is considered as good one. +func (d *DefaultDiscovery) UnregisterConnectedAddr(s string) { + d.unconnectedCh <- s +} + +func (d *DefaultDiscovery) tryAddress(addr string) { + if err := d.transport.Dial(addr, d.dialTimeout); err != nil { + d.badAddrCh <- addr + } else { + d.connectedCh <- addr } } -func (d *DefaultDiscovery) next() string { - return <-d.pool +func (d *DefaultDiscovery) requestToWork() { + var requested int + + for { + for requested = <-d.requestCh; requested > 0; requested-- { + select { + case r := <-d.requestCh: + if requested < r { + requested = r + } + case addr := <-d.pool: + if !d.connectedAddrs[addr] { + go d.tryAddress(addr) + } + } + } + } } func (d *DefaultDiscovery) run() { - var ( - maxWorkers = 5 - workCh = make(chan string) - ) - - for i := 0; i < maxWorkers; i++ { - go d.work(workCh) - } - + go d.requestToWork() for { select { case addr := <-d.backFill: - if _, ok := d.badAddrs[addr]; ok { + if d.badAddrs[addr] || d.connectedAddrs[addr] || + d.unconnectedAddrs[addr] > 0 { break } - if _, ok := d.addrs[addr]; !ok { - d.addrs[addr] = true - d.unconnectedAddrs[addr] = true - d.pool <- addr - } - case n := <-d.requestCh: - go func() { - for i := 0; i < n; i++ { - workCh <- d.next() - } - }() + d.unconnectedAddrs[addr] = connRetries + d.pushToPoolOrDrop(addr) case addr := <-d.badAddrCh: - d.badAddrs[addr] = true - delete(d.unconnectedAddrs, addr) - go func() { - workCh <- d.next() - }() + d.unconnectedAddrs[addr]-- + if d.unconnectedAddrs[addr] > 0 { + d.pushToPoolOrDrop(addr) + } else { + d.badAddrs[addr] = true + delete(d.unconnectedAddrs, addr) + } + d.RequestRemote(1) case addr := <-d.connectedCh: delete(d.unconnectedAddrs, addr) + if !d.connectedAddrs[addr] { + d.connectedAddrs[addr] = true + } + case addr := <-d.goodCh: + if !d.goodAddrs[addr] { + d.goodAddrs[addr] = true + } + case addr := <-d.unconnectedCh: + delete(d.connectedAddrs, addr) } } } diff --git a/pkg/network/handshakestage_string.go b/pkg/network/handshakestage_string.go new file mode 100644 index 000000000..40cac11bb --- /dev/null +++ b/pkg/network/handshakestage_string.go @@ -0,0 +1,27 @@ +// Code generated by "stringer -type=handShakeStage"; DO NOT EDIT. + +package network + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[nothingDone-0] + _ = x[versionSent-1] + _ = x[versionReceived-2] + _ = x[verAckSent-3] + _ = x[verAckReceived-4] +} + +const _handShakeStage_name = "nothingDoneversionSentversionReceivedverAckSentverAckReceived" + +var _handShakeStage_index = [...]uint8{0, 11, 22, 37, 47, 61} + +func (i handShakeStage) String() string { + if i >= handShakeStage(len(_handShakeStage_index)-1) { + return "handShakeStage(" + strconv.FormatInt(int64(i), 10) + ")" + } + return _handShakeStage_name[_handShakeStage_index[i]:_handShakeStage_index[i+1]] +} diff --git a/pkg/network/helper_test.go b/pkg/network/helper_test.go index c231f96dd..bcbdff5df 100644 --- a/pkg/network/helper_test.go +++ b/pkg/network/helper_test.go @@ -93,12 +93,15 @@ func (chain testChain) Verify(*transaction.Transaction) error { type testDiscovery struct{} -func (d testDiscovery) BackFill(addrs ...string) {} -func (d testDiscovery) PoolCount() int { return 0 } -func (d testDiscovery) RegisterBadAddr(string) {} -func (d testDiscovery) UnconnectedPeers() []string { return []string{} } -func (d testDiscovery) RequestRemote(n int) {} -func (d testDiscovery) BadPeers() []string { return []string{} } +func (d testDiscovery) BackFill(addrs ...string) {} +func (d testDiscovery) PoolCount() int { return 0 } +func (d testDiscovery) RegisterBadAddr(string) {} +func (d testDiscovery) RegisterGoodAddr(string) {} +func (d testDiscovery) UnregisterConnectedAddr(string) {} +func (d testDiscovery) UnconnectedPeers() []string { return []string{} } +func (d testDiscovery) RequestRemote(n int) {} +func (d testDiscovery) BadPeers() []string { return []string{} } +func (d testDiscovery) GoodPeers() []string { return []string{} } type localTransport struct{} @@ -114,6 +117,7 @@ var defaultMessageHandler = func(t *testing.T, msg *Message) {} type localPeer struct { netaddr net.TCPAddr version *payload.Version + handshaked bool t *testing.T messageHandler func(t *testing.T, msg *Message) } @@ -142,8 +146,23 @@ func (p *localPeer) Done() chan error { func (p *localPeer) Version() *payload.Version { return p.version } -func (p *localPeer) SetVersion(v *payload.Version) { +func (p *localPeer) HandleVersion(v *payload.Version) error { p.version = v + return nil +} +func (p *localPeer) SendVersion(m *Message) error { + return p.WriteMsg(m) +} +func (p *localPeer) SendVersionAck(m *Message) error { + return p.WriteMsg(m) +} +func (p *localPeer) HandleVersionAck() error { + p.handshaked = true + return nil +} + +func (p *localPeer) Handshaked() bool { + return p.handshaked } func newTestServer() *Server { diff --git a/pkg/network/payload/address.go b/pkg/network/payload/address.go index c891753e5..a83498a7d 100644 --- a/pkg/network/payload/address.go +++ b/pkg/network/payload/address.go @@ -3,6 +3,7 @@ package payload import ( "io" "net" + "strconv" "time" "github.com/CityOfZion/neo-go/pkg/util" @@ -47,11 +48,28 @@ func (p *AddressAndTime) EncodeBinary(w io.Writer) error { return bw.Err } +// IPPortString makes a string from IP and port specified. +func (p *AddressAndTime) IPPortString() string { + var netip net.IP = make(net.IP, 16) + + copy(netip, p.IP[:]) + port := strconv.Itoa(int(p.Port)) + return netip.String() + ":" + port +} + // AddressList is a list with AddrAndTime. type AddressList struct { Addrs []*AddressAndTime } +// NewAddressList creates a list for n AddressAndTime elements. +func NewAddressList(n int) *AddressList { + alist := AddressList{ + Addrs: make([]*AddressAndTime, n), + } + return &alist +} + // DecodeBinary implements the Payload interface. func (p *AddressList) DecodeBinary(r io.Reader) error { br := util.BinReader{R: r} diff --git a/pkg/network/payload/address_test.go b/pkg/network/payload/address_test.go index 154bf017d..afece415a 100644 --- a/pkg/network/payload/address_test.go +++ b/pkg/network/payload/address_test.go @@ -35,7 +35,7 @@ func TestEncodeDecodeAddress(t *testing.T) { func TestEncodeDecodeAddressList(t *testing.T) { var lenList uint8 = 4 - addrList := &AddressList{make([]*AddressAndTime, lenList)} + addrList := NewAddressList(int(lenList)) for i := 0; i < int(lenList); i++ { e, _ := net.ResolveTCPAddr("tcp", fmt.Sprintf("127.0.0.1:200%d", i)) addrList.Addrs[i] = NewAddressAndTime(e, time.Now()) diff --git a/pkg/network/payload/payload.go b/pkg/network/payload/payload.go index 90d5b0d74..e23f9955a 100644 --- a/pkg/network/payload/payload.go +++ b/pkg/network/payload/payload.go @@ -7,3 +7,22 @@ type Payload interface { EncodeBinary(io.Writer) error DecodeBinary(io.Reader) error } + +// NullPayload is a dummy payload with no fields. +type NullPayload struct { +} + +// NewNullPayload returns zero-sized stub payload. +func NewNullPayload() *NullPayload { + return &NullPayload{} +} + +// DecodeBinary implements the Payload interface. +func (p *NullPayload) DecodeBinary(r io.Reader) error { + return nil +} + +// EncodeBinary implements the Payload interface. +func (p *NullPayload) EncodeBinary(r io.Writer) error { + return nil +} diff --git a/pkg/network/peer.go b/pkg/network/peer.go index 1298d615a..620aa3d91 100644 --- a/pkg/network/peer.go +++ b/pkg/network/peer.go @@ -13,5 +13,9 @@ type Peer interface { WriteMsg(msg *Message) error Done() chan error Version() *payload.Version - SetVersion(*payload.Version) + Handshaked() bool + SendVersion(*Message) error + SendVersionAck(*Message) error + HandleVersion(*payload.Version) error + HandleVersionAck() error } diff --git a/pkg/network/server.go b/pkg/network/server.go index 84d0f7ab3..1de1b060d 100644 --- a/pkg/network/server.go +++ b/pkg/network/server.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "math/rand" + "net" "sync" "time" @@ -15,13 +16,15 @@ import ( ) const ( - minPeers = 5 - maxBlockBatch = 200 - minPoolCount = 30 + // peer numbers are arbitrary at the moment + minPeers = 5 + maxPeers = 20 + maxBlockBatch = 200 + maxAddrsToSend = 200 + minPoolCount = 30 ) var ( - errPortMismatch = errors.New("port mismatch") errIdenticalID = errors.New("identical node id") errInvalidHandshake = errors.New("invalid handshake") errInvalidNetwork = errors.New("invalid network") @@ -46,6 +49,7 @@ type ( lock sync.RWMutex peers map[Peer]bool + addrReq chan *Message register chan Peer unregister chan peerDrop quit chan struct{} @@ -64,6 +68,7 @@ func NewServer(config ServerConfig, chain core.Blockchainer) *Server { chain: chain, id: rand.Uint32(), quit: make(chan struct{}), + addrReq: make(chan *Message, minPeers), register: make(chan Peer), unregister: make(chan peerDrop), peers: make(map[Peer]bool), @@ -90,12 +95,7 @@ func (s *Server) Start(errChan chan error) { "headerHeight": s.chain.HeaderHeight(), }).Info("node started") - for _, addr := range s.Seeds { - if err := s.transport.Dial(addr, s.DialTimeout); err != nil { - log.Warnf("failed to connect to remote node %s", addr) - continue - } - } + s.discovery.BackFill(s.Seeds...) go s.transport.Accept() s.run() @@ -122,6 +122,19 @@ func (s *Server) BadPeers() []string { func (s *Server) run() { for { + c := s.PeerCount() + if c < minPeers { + s.discovery.RequestRemote(maxPeers - c) + } + if s.discovery.PoolCount() < minPoolCount { + select { + case s.addrReq <- NewMessage(s.Net, CMDGetAddr, payload.NewNullPayload()): + // sent request + default: + // we have one in the queue already that is + // gonna be served by some worker when it's ready + } + } select { case <-s.quit: s.transport.Close() @@ -141,12 +154,19 @@ func (s *Server) run() { "addr": p.NetAddr(), }).Info("new peer connected") case drop := <-s.unregister: - delete(s.peers, drop.peer) - log.WithFields(log.Fields{ - "addr": drop.peer.NetAddr(), - "reason": drop.reason, - "peerCount": s.PeerCount(), - }).Warn("peer disconnected") + if s.peers[drop.peer] { + delete(s.peers, drop.peer) + log.WithFields(log.Fields{ + "addr": drop.peer.NetAddr(), + "reason": drop.reason, + "peerCount": s.PeerCount(), + }).Warn("peer disconnected") + addr := drop.peer.NetAddr().String() + s.discovery.UnregisterConnectedAddr(addr) + s.discovery.BackFill(addr) + } + // else the peer is already gone, which can happen + // because we have two goroutines sending signals here } } } @@ -174,20 +194,34 @@ func (s *Server) startProtocol(p Peer) { "id": p.Version().Nonce, }).Info("started protocol") - s.requestHeaders(p) + s.discovery.RegisterGoodAddr(p.NetAddr().String()) + err := s.requestHeaders(p) + if err != nil { + p.Disconnect(err) + return + } timer := time.NewTimer(s.ProtoTickInterval) for { select { - case err := <-p.Done(): - s.unregister <- peerDrop{p, err} - return + case err = <-p.Done(): + // time to stop + case m := <-s.addrReq: + err = p.WriteMsg(m) case <-timer.C: // Try to sync in headers and block with the peer if his block height is higher then ours. if p.Version().StartHeight > s.chain.BlockHeight() { - s.requestBlocks(p) + err = s.requestBlocks(p) } - timer.Reset(s.ProtoTickInterval) + if err == nil { + timer.Reset(s.ProtoTickInterval) + } + } + if err != nil { + s.unregister <- peerDrop{p, err} + timer.Stop() + p.Disconnect(err) + return } } } @@ -201,20 +235,23 @@ func (s *Server) sendVersion(p Peer) error { s.chain.BlockHeight(), s.Relay, ) - return p.WriteMsg(NewMessage(s.Net, CMDVersion, payload)) + return p.SendVersion(NewMessage(s.Net, CMDVersion, payload)) } // When a peer sends out his version we reply with verack after validating // the version. func (s *Server) handleVersionCmd(p Peer, version *payload.Version) error { - if p.NetAddr().Port != int(version.Port) { - return errPortMismatch + err := p.HandleVersion(version) + if err != nil { + return err } if s.id == version.Nonce { return errIdenticalID } - p.SetVersion(version) - return p.WriteMsg(NewMessage(s.Net, CMDVerack, nil)) + if p.NetAddr().Port != int(version.Port) { + return fmt.Errorf("port mismatch: connected to %d and peer sends %d", p.NetAddr().Port, version.Port) + } + return p.SendVersionAck(NewMessage(s.Net, CMDVerack, nil)) } // handleHeadersCmd will process the headers it received from its peer. @@ -251,18 +288,42 @@ func (s *Server) handleInvCmd(p Peer, inv *payload.Inventory) error { return p.WriteMsg(NewMessage(s.Net, CMDGetData, payload)) } +// handleAddrCmd will process received addresses. +func (s *Server) handleAddrCmd(p Peer, addrs *payload.AddressList) error { + for _, a := range addrs.Addrs { + s.discovery.BackFill(a.IPPortString()) + } + return nil +} + +// handleGetAddrCmd sends to the peer some good addresses that we know of. +func (s *Server) handleGetAddrCmd(p Peer) error { + addrs := s.discovery.GoodPeers() + if len(addrs) > maxAddrsToSend { + addrs = addrs[:maxAddrsToSend] + } + alist := payload.NewAddressList(len(addrs)) + ts := time.Now() + for i, addr := range addrs { + // we know it's a good address, so it can't fail + netaddr, _ := net.ResolveTCPAddr("tcp", addr) + alist.Addrs[i] = payload.NewAddressAndTime(netaddr, ts) + } + return p.WriteMsg(NewMessage(s.Net, CMDAddr, alist)) +} + // requestHeaders will send a getheaders message to the peer. // The peer will respond with headers op to a count of 2000. -func (s *Server) requestHeaders(p Peer) { +func (s *Server) requestHeaders(p Peer) error { start := []util.Uint256{s.chain.CurrentHeaderHash()} payload := payload.NewGetBlocks(start, util.Uint256{}) - p.WriteMsg(NewMessage(s.Net, CMDGetHeaders, payload)) + return p.WriteMsg(NewMessage(s.Net, CMDGetHeaders, payload)) } // requestBlocks will send a getdata message to the peer // to sync up in blocks. A maximum of maxBlockBatch will // send at once. -func (s *Server) requestBlocks(p Peer) { +func (s *Server) requestBlocks(p Peer) error { var ( hashes []util.Uint256 hashStart = s.chain.BlockHeight() + 1 @@ -275,10 +336,11 @@ func (s *Server) requestBlocks(p Peer) { } if len(hashes) > 0 { payload := payload.NewInventory(payload.BlockType, hashes) - p.WriteMsg(NewMessage(s.Net, CMDGetData, payload)) + return p.WriteMsg(NewMessage(s.Net, CMDGetData, payload)) } else if s.chain.HeaderHeight() < p.Version().StartHeight { - s.requestHeaders(p) + return s.requestHeaders(p) } + return nil } // handleMessage will process the given message. @@ -289,26 +351,40 @@ func (s *Server) handleMessage(peer Peer, msg *Message) error { return errInvalidNetwork } - switch msg.CommandType() { - case CMDVersion: - version := msg.Payload.(*payload.Version) - return s.handleVersionCmd(peer, version) - case CMDHeaders: - headers := msg.Payload.(*payload.Headers) - go s.handleHeadersCmd(peer, headers) - case CMDInv: - inventory := msg.Payload.(*payload.Inventory) - return s.handleInvCmd(peer, inventory) - case CMDBlock: - block := msg.Payload.(*core.Block) - return s.handleBlockCmd(peer, block) - case CMDVerack: - // Make sure this peer has send his version before we start the - // protocol with that peer. - if peer.Version() == nil { - return errInvalidHandshake + if peer.Handshaked() { + switch msg.CommandType() { + case CMDAddr: + addrs := msg.Payload.(*payload.AddressList) + return s.handleAddrCmd(peer, addrs) + case CMDGetAddr: + // it has no payload + return s.handleGetAddrCmd(peer) + case CMDHeaders: + headers := msg.Payload.(*payload.Headers) + go s.handleHeadersCmd(peer, headers) + case CMDInv: + inventory := msg.Payload.(*payload.Inventory) + return s.handleInvCmd(peer, inventory) + case CMDBlock: + block := msg.Payload.(*core.Block) + return s.handleBlockCmd(peer, block) + case CMDVersion, CMDVerack: + return fmt.Errorf("received '%s' after the handshake", msg.CommandType()) + } + } else { + switch msg.CommandType() { + case CMDVersion: + version := msg.Payload.(*payload.Version) + return s.handleVersionCmd(peer, version) + case CMDVerack: + err := peer.HandleVersionAck() + if err != nil { + return err + } + go s.startProtocol(peer) + default: + return fmt.Errorf("received '%s' during handshake", msg.CommandType()) } - go s.startProtocol(peer) } return nil } diff --git a/pkg/network/server_test.go b/pkg/network/server_test.go index 0d4e1238b..7d96b7989 100644 --- a/pkg/network/server_test.go +++ b/pkg/network/server_test.go @@ -71,7 +71,7 @@ func TestServerNotSendsVerack(t *testing.T) { version := payload.NewVersion(1337, 2000, "/NEO-GO/", 0, true) err := s.handleVersionCmd(p, version) assert.NotNil(t, err) - assert.Equal(t, errPortMismatch, err) + assert.Contains(t, err.Error(), "port mismatch") // identical id's version = payload.NewVersion(1, 3000, "/NEO-GO/", 0, true) diff --git a/pkg/network/tcp_peer.go b/pkg/network/tcp_peer.go index e9e81ffa9..10dac0a82 100644 --- a/pkg/network/tcp_peer.go +++ b/pkg/network/tcp_peer.go @@ -1,12 +1,29 @@ package network import ( + "errors" + "fmt" "net" "sync" "github.com/CityOfZion/neo-go/pkg/network/payload" ) +type handShakeStage uint8 + +//go:generate stringer -type=handShakeStage +const ( + nothingDone handShakeStage = 0 + versionSent handShakeStage = 1 + versionReceived handShakeStage = 2 + verAckSent handShakeStage = 3 + verAckReceived handShakeStage = 4 +) + +var ( + errStateMismatch = errors.New("tried to send protocol message before handshake completed") +) + // TCPPeer represents a connected remote node in the // network over TCP. type TCPPeer struct { @@ -17,6 +34,8 @@ type TCPPeer struct { // The version of the peer. version *payload.Version + handShake handShakeStage + done chan error wg sync.WaitGroup @@ -35,8 +54,16 @@ func NewTCPPeer(conn net.Conn) *TCPPeer { } // WriteMsg implements the Peer interface. This will write/encode the message -// to the underlying connection. +// to the underlying connection, this only works for messages other than Version +// or VerAck. func (p *TCPPeer) WriteMsg(msg *Message) error { + if !p.Handshaked() { + return errStateMismatch + } + return p.writeMsg(msg) +} + +func (p *TCPPeer) writeMsg(msg *Message) error { select { case err := <-p.done: return err @@ -45,6 +72,55 @@ func (p *TCPPeer) WriteMsg(msg *Message) error { } } +// Handshaked returns status of the handshake, whether it's completed or not. +func (p *TCPPeer) Handshaked() bool { + return p.handShake == verAckReceived +} + +// SendVersion checks for the handshake state and sends a message to the peer. +func (p *TCPPeer) SendVersion(msg *Message) error { + if p.handShake != nothingDone { + return fmt.Errorf("invalid handshake: tried to send Version in %s state", p.handShake.String()) + } + err := p.writeMsg(msg) + if err == nil { + p.handShake = versionSent + } + return err +} + +// HandleVersion checks for the handshake state and version message contents. +func (p *TCPPeer) HandleVersion(version *payload.Version) error { + if p.handShake != versionSent { + return fmt.Errorf("invalid handshake: received Version in %s state", p.handShake.String()) + } + p.version = version + p.handShake = versionReceived + return nil +} + +// SendVersionAck checks for the handshake state and sends a message to the peer. +func (p *TCPPeer) SendVersionAck(msg *Message) error { + if p.handShake != versionReceived { + return fmt.Errorf("invalid handshake: tried to send VersionAck in %s state", p.handShake.String()) + } + err := p.writeMsg(msg) + if err == nil { + p.handShake = verAckSent + } + return err +} + +// HandleVersionAck checks handshake sequence correctness when VerAck message +// is received. +func (p *TCPPeer) HandleVersionAck() error { + if p.handShake != verAckSent { + return fmt.Errorf("invalid handshake: received VersionAck in %s state", p.handShake.String()) + } + p.handShake = verAckReceived + return nil +} + // NetAddr implements the Peer interface. func (p *TCPPeer) NetAddr() *net.TCPAddr { return &p.addr @@ -59,15 +135,16 @@ func (p *TCPPeer) Done() chan error { // Disconnect will fill the peer's done channel with the given error. func (p *TCPPeer) Disconnect(err error) { - p.done <- err + p.conn.Close() + select { + case p.done <- err: + // one message to the queue + default: + // the other side may already be gone, it's OK + } } // Version implements the Peer interface. func (p *TCPPeer) Version() *payload.Version { return p.version } - -// SetVersion implements the Peer interface. -func (p *TCPPeer) SetVersion(v *payload.Version) { - p.version = v -} diff --git a/pkg/network/tcp_transport.go b/pkg/network/tcp_transport.go index 82ea5f351..8f61b7607 100644 --- a/pkg/network/tcp_transport.go +++ b/pkg/network/tcp_transport.go @@ -75,21 +75,19 @@ func (t *TCPTransport) handleConn(conn net.Conn) { err error ) - defer func() { - p.Disconnect(err) - }() - t.server.register <- p for { msg := &Message{} if err = msg.Decode(p.conn); err != nil { - return + break } if err = t.server.handleMessage(p, msg); err != nil { - return + break } } + t.server.unregister <- peerDrop{p, err} + p.Disconnect(err) } // Close implements the Transporter interface.