diff --git a/pkg/network/addr_payload.go b/pkg/network/addr_payload.go deleted file mode 100644 index cb2c61098..000000000 --- a/pkg/network/addr_payload.go +++ /dev/null @@ -1,25 +0,0 @@ -package network - -import "net" - -// AddrWithTimestamp payload. -type AddrWithTimestamp struct { - t uint32 - services uint64 - endpoint net.Addr -} - -func newAddrWithTimestampFromPeer(p *Peer) AddrWithTimestamp { - return AddrWithTimestamp{ - t: 1223345, - services: 1, - endpoint: p.conn.RemoteAddr(), - } -} - -// AddrPayload container a list of known peer addresses. -type AddrPayload []AddrWithTimestamp - -func (p AddrPayload) encode() ([]byte, error) { - return nil, nil -} diff --git a/pkg/network/message.go b/pkg/network/message.go index fd4ea61c7..4589e64c5 100644 --- a/pkg/network/message.go +++ b/pkg/network/message.go @@ -5,6 +5,7 @@ import ( "crypto/sha256" "encoding/binary" "errors" + "fmt" "io" "github.com/anthdm/neo-go/pkg/network/payload" @@ -92,23 +93,6 @@ func newMessage(magic NetMode, cmd commandType, p payload.Payloader) *Message { } } -func TeeWriter(w io.Writer, r io.Reader) io.Writer { - return &teeWriter{w, r} -} - -type teeWriter struct { - w io.Writer - r io.Reader -} - -func (w *teeWriter) Write(b []byte) (n int, err error) { - n, err = w.w.Write(b) - if n > 0 { - n, err = w.r.Read(b[:n]) - } - return -} - // Converts the 12 byte command slice to a commandType. func (m *Message) commandType() commandType { cmd := string(bytes.TrimRight(m.Command, "\x00")) @@ -153,11 +137,16 @@ func (m *Message) decode(r io.Reader) error { m.Length = binary.LittleEndian.Uint32(buf[16:20]) m.Checksum = binary.LittleEndian.Uint32(buf[20:24]) - // payload is 0, so dont decode it. - if m.Length == 0 { + // if their is no payload. + if m.Length == 0 || !needPayloadDecode(m.commandType()) { return nil } + return m.decodePayload(r) +} + +func (m *Message) decodePayload(r io.Reader) error { + // write to a buffer what we read to calculate the checksum. buffer := new(bytes.Buffer) tr := io.TeeReader(r, buffer) var p payload.Payloader @@ -168,6 +157,8 @@ func (m *Message) decode(r io.Reader) error { if err := p.Decode(tr); err != nil { return err } + default: + return fmt.Errorf("unknown command to decode: %s", m.commandType()) } // Compare the checksum of the payload. @@ -186,14 +177,18 @@ func (m *Message) encode(w io.Writer) error { pbuf := new(bytes.Buffer) // if there is a payload fill its allocated buffer. + var checksum []byte if m.Payload != nil { if err := m.Payload.Encode(pbuf); err != nil { return err } - checksum := sumSHA256(sumSHA256(pbuf.Bytes()))[:4] - m.Checksum = binary.LittleEndian.Uint32(checksum) + checksum = sumSHA256(sumSHA256(pbuf.Bytes()))[:4] + } else { + checksum = sumSHA256(sumSHA256([]byte{}))[:4] } + m.Checksum = binary.LittleEndian.Uint32(checksum) + // fill the message buffer binary.LittleEndian.PutUint32(buf[0:4], uint32(m.Magic)) copy(buf[4:16], m.Command) @@ -243,6 +238,10 @@ func cmdToByteSlice(cmd commandType) []byte { return b } +func needPayloadDecode(cmd commandType) bool { + return cmd != cmdVerack && cmd != cmdGetAddr +} + func sumSHA256(b []byte) []byte { h := sha256.New() h.Write(b) diff --git a/pkg/network/message_test.go b/pkg/network/message_test.go index c4ce6072c..35ee154a5 100644 --- a/pkg/network/message_test.go +++ b/pkg/network/message_test.go @@ -68,20 +68,20 @@ func TestMessageEncodeDecodeWithVersion(t *testing.T) { t.Log(p1) } -// func TestMessageInvalidChecksum(t *testing.T) { -// m := newMessage(ModeTestNet, cmdVersion, []byte{}) -// m.Checksum = 1337 +func TestMessageInvalidChecksum(t *testing.T) { + m := newMessage(ModeTestNet, cmdVersion, nil) + m.Checksum = 1337 -// buf := &bytes.Buffer{} -// if err := m.encode(buf); err != nil { -// t.Error(err) -// } + buf := &bytes.Buffer{} + if err := m.encode(buf); err != nil { + t.Error(err) + } -// md := &Message{} -// if err := md.decode(buf); err == nil { -// t.Error("decode should failed with checkum mismatch error") -// } -// } + md := &Message{} + if err := md.decode(buf); err == nil { + t.Error("decode should failed with checkum mismatch error") + } +} // func TestNewVersionPayload(t *testing.T) { // ua := "/neo/0.0.1/" diff --git a/pkg/network/payload/payloader.go b/pkg/network/payload/payloader.go index 21e680524..015b43066 100644 --- a/pkg/network/payload/payloader.go +++ b/pkg/network/payload/payloader.go @@ -2,18 +2,10 @@ package payload import "io" -// Nothing is a safe non payload. -var Nothing = nothing{} - -// Payloader .. +// Payloader is anything that can be binary encoded and decoded. +// Every payload used in messages need to satisfy the Payloader interface. type Payloader interface { Encode(io.Writer) error Decode(io.Reader) error Size() uint32 } - -type nothing struct{} - -func (p nothing) Encode(w io.Writer) error { return nil } -func (p nothing) Decode(R io.Reader) error { return nil } -func (p nothing) Size() uint32 { return 0 } diff --git a/pkg/network/payload/version.go b/pkg/network/payload/version.go index 88ddb0b43..526dde118 100644 --- a/pkg/network/payload/version.go +++ b/pkg/network/payload/version.go @@ -1,12 +1,14 @@ package payload import ( - "bytes" "encoding/binary" "io" ) -const minVersionSize = 27 +const ( + lenUA = 12 + minVersionSize = 27 + lenUA +) // Version payload. type Version struct { @@ -20,7 +22,7 @@ type Version struct { Port uint16 // it's used to distinguish the node from public IP Nonce uint32 - // client id currently 6 bytes \v/NEO:2.6.0/ + // client id currently 12 bytes \v/NEO:2.6.0/ UserAgent []byte // Height of the block chain StartHeight uint32 @@ -42,20 +44,19 @@ func NewVersion(p uint16, ua string, h uint32, r bool) *Version { } } -// Size .. +// Size implements the Payloader interface. func (p *Version) Size() uint32 { - n := minVersionSize + len(p.UserAgent) + n := minVersionSize return uint32(n) } -// Decode .. +// Decode implements the Payloader interface. func (p *Version) Decode(r io.Reader) error { - buf := new(bytes.Buffer) - if _, err := buf.ReadFrom(r); err != nil { + b := make([]byte, minVersionSize) + if _, err := r.Read(b); err != nil { return err } - b := buf.Bytes() // 27 bytes for the fixed size fields + the length of the user agent // which is kinda variable, according to the docs. lenUA := len(b) - minVersionSize @@ -75,7 +76,7 @@ func (p *Version) Decode(r io.Reader) error { return nil } -// Encode .. +// Encode implements the Payloader interface. func (p *Version) Encode(w io.Writer) error { buf := make([]byte, p.Size()) diff --git a/pkg/network/peer.go b/pkg/network/peer.go index 64c3832ec..68adb4b00 100644 --- a/pkg/network/peer.go +++ b/pkg/network/peer.go @@ -38,7 +38,7 @@ func (p *Peer) writeLoop() { for { msg := <-p.send - rpcLogger.Printf("OUT :: %+v", msg) + rpcLogger.Printf("OUT :: %s", msg.commandType()) if err := msg.encode(p.conn); err != nil { log.Printf("encode error: %s", err) } diff --git a/pkg/network/server.go b/pkg/network/server.go index 2ee9d521d..99c411c5f 100644 --- a/pkg/network/server.go +++ b/pkg/network/server.go @@ -167,11 +167,7 @@ func (s *Server) processMessage(msg *Message, peer *Peer) error { switch msg.commandType() { case cmdVersion: - // v, err := msg.decodePayload() - // if err != nil { - // return err - // } - // return s.handleVersionCmd(v.(*Version), peer) + return s.handleVersionCmd(msg.Payload.(*payload.Version), peer) case cmdVerack: case cmdGetAddr: return s.handleGetAddrCmd(msg, peer) @@ -226,10 +222,10 @@ func (s *Server) handleGetAddrCmd(msg *Message, peer *Peer) error { // if err != nil { // return err // } - var addrList []AddrWithTimestamp - for peer := range s.peers { - addrList = append(addrList, newAddrWithTimestampFromPeer(peer)) - } + // var addrList []AddrWithTimestamp + // for peer := range s.peers { + // addrList = append(addrList, newAddrWithTimestampFromPeer(peer)) + // } return nil }