wip refact 2

This commit is contained in:
anthdm 2018-01-27 16:00:28 +01:00
parent ccaaf07dad
commit 754a473488
9 changed files with 362 additions and 177 deletions

View file

@ -0,0 +1,25 @@
package network
import "net"
// AddrWithTimestamp payload.
type AddrWithTimestamp struct {
t uint32
services uint64
endpoint net.Addr
}
func newAddrWithTimestampFromPeer(p *Peer) AddrWithTimestamp {
return AddrWithTimestamp{
t: 1223345,
services: 1,
endpoint: p.conn.RemoteAddr(),
}
}
// AddrPayload container a list of known peer addresses.
type AddrPayload []AddrWithTimestamp
func (p AddrPayload) encode() ([]byte, error) {
return nil, nil
}

View file

@ -6,6 +6,8 @@ import (
"encoding/binary" "encoding/binary"
"errors" "errors"
"io" "io"
"github.com/anthdm/neo-go/pkg/network/payload"
) )
const ( const (
@ -57,7 +59,7 @@ type Message struct {
// hash of the payload // hash of the payload
Checksum uint32 Checksum uint32
// Payload send with the message. // Payload send with the message.
Payload []byte Payload payload.Payloader
} }
type commandType string type commandType string
@ -77,17 +79,34 @@ const (
cmdTX = "tx" cmdTX = "tx"
) )
func newMessage(magic NetMode, cmd commandType, payload []byte) *Message { func newMessage(magic NetMode, cmd commandType, p payload.Payloader) *Message {
sum := sumSHA256(sumSHA256(payload))[:4] var size uint32
sumuint32 := binary.LittleEndian.Uint32(sum) if p != nil {
size = p.Size()
return &Message{
Magic: magic,
Command: cmdToByteSlice(cmd),
Length: uint32(len(payload)),
Checksum: sumuint32,
Payload: payload,
} }
return &Message{
Magic: magic,
Command: cmdToByteSlice(cmd),
Length: size,
Payload: p,
}
}
func TeeWriter(w io.Writer, r io.Reader) io.Writer {
return &teeWriter{w, r}
}
type teeWriter struct {
w io.Writer
r io.Reader
}
func (w *teeWriter) Write(b []byte) (n int, err error) {
n, err = w.w.Write(b)
if n > 0 {
n, err = w.r.Read(b[:n])
}
return
} }
// Converts the 12 byte command slice to a commandType. // Converts the 12 byte command slice to a commandType.
@ -134,140 +153,79 @@ func (m *Message) decode(r io.Reader) error {
m.Length = binary.LittleEndian.Uint32(buf[16:20]) m.Length = binary.LittleEndian.Uint32(buf[16:20])
m.Checksum = binary.LittleEndian.Uint32(buf[20:24]) m.Checksum = binary.LittleEndian.Uint32(buf[20:24])
payload := make([]byte, m.Length) // payload is 0, so dont decode it.
if _, err := r.Read(payload); err != nil { if m.Length == 0 {
return err return nil
}
buffer := new(bytes.Buffer)
tr := io.TeeReader(r, buffer)
var p payload.Payloader
switch m.commandType() {
case cmdVersion:
p = &payload.Version{}
if err := p.Decode(tr); err != nil {
return err
}
} }
// Compare the checksum of the payload. // Compare the checksum of the payload.
if !compareChecksum(m.Checksum, payload) { if !compareChecksum(m.Checksum, buffer.Bytes()) {
return errors.New("checksum mismatch error") return errors.New("checksum mismatch error")
} }
m.Payload = payload m.Payload = p
return nil return nil
} }
// encode a Message to any given io.Writer. // encode a Message to any given io.Writer.
func (m *Message) encode(w io.Writer) error { func (m *Message) encode(w io.Writer) error {
// 24 bytes for the fixed sized fields + the length of the payload. buf := make([]byte, minMessageSize)
buf := make([]byte, minMessageSize+m.Length) pbuf := new(bytes.Buffer)
// if there is a payload fill its allocated buffer.
if m.Payload != nil {
if err := m.Payload.Encode(pbuf); err != nil {
return err
}
checksum := sumSHA256(sumSHA256(pbuf.Bytes()))[:4]
m.Checksum = binary.LittleEndian.Uint32(checksum)
}
// fill the message buffer
binary.LittleEndian.PutUint32(buf[0:4], uint32(m.Magic)) binary.LittleEndian.PutUint32(buf[0:4], uint32(m.Magic))
copy(buf[4:16], m.Command) copy(buf[4:16], m.Command)
binary.LittleEndian.PutUint32(buf[16:20], m.Length) binary.LittleEndian.PutUint32(buf[16:20], m.Length)
binary.LittleEndian.PutUint32(buf[20:24], m.Checksum) binary.LittleEndian.PutUint32(buf[20:24], m.Checksum)
copy(buf[24:len(buf)], m.Payload)
_, err := w.Write(buf) // write the message
return err n, err := w.Write(buf)
} if err != nil {
return err
}
func (m *Message) decodePayload() (interface{}, error) { // we need to have at least writen exactly minMessageSize bytes.
switch m.commandType() { if n != minMessageSize {
case cmdVersion: return errors.New("long/short read error when encoding message")
v := &Version{} }
if err := v.decode(m.Payload); err != nil {
return nil, err // write the payload if there was any
if pbuf.Len() > 0 {
n, err = w.Write(pbuf.Bytes())
if err != nil {
return err
}
if uint32(n) != m.Payload.Size() {
return errors.New("long/short read error when encoding payload")
} }
return v, nil
} }
return nil, nil
}
// Version payload description.
//
// Size Field DataType Description
// ---------------------------------------------------------------------------------------------
// 4 Version uint32 Version of protocol, 0 for now
// 8 Services uint64 The service provided by the node is currently 1
// 4 Timestamp uint32 Current time
// 2 Port uint16 Port that the server is listening on, it's 0 if not used.
// 4 Nonce uint32 It's used to distinguish the node from public IP
// ? UserAgent varstr Client ID
// 4 StartHeight uint32 Height of block chain
// 1 Relay bool Whether to receive and forward
type Version struct {
// currently the version of the protocol is 0
Version uint32
// currently 1
Services uint64
// timestamp
Timestamp uint32
// port this server is listening on
Port uint16
// it's used to distinguish the node from public IP
Nonce uint32
// client id
UserAgent []byte // ?
// Height of the block chain
StartHeight uint32
// Whether to receive and forward
Relay bool
}
func newVersionPayload(p uint16, ua string, h uint32, r bool) *Version {
return &Version{
Version: 0,
Services: 1,
Timestamp: 12345,
Port: p,
Nonce: 19110,
UserAgent: []byte(ua),
StartHeight: 0,
Relay: r,
}
}
func (p *Version) decode(b []byte) error {
// Fixed fields have a total of 27 bytes. We substract this size
// with the total buffer length to know the length of the user agent.
lenUA := len(b) - 27
p.Version = binary.LittleEndian.Uint32(b[0:4])
p.Services = binary.LittleEndian.Uint64(b[4:12])
p.Timestamp = binary.LittleEndian.Uint32(b[12:16])
// FIXME: port's byteorder should be big endian according to the docs.
// but when connecting to the privnet docker image it's little endian.
p.Port = binary.LittleEndian.Uint16(b[16:18])
p.Nonce = binary.LittleEndian.Uint32(b[18:22])
p.UserAgent = b[22 : 22+lenUA]
curlen := 22 + lenUA
p.StartHeight = binary.LittleEndian.Uint32(b[curlen : curlen+4])
p.Relay = b[len(b)-1 : len(b)][0] == 1
return nil return nil
} }
func (p *Version) encode() ([]byte, error) {
// 27 bytes for the fixed size fields + the length of the user agent
// which is kinda variable, according to the docs.
buf := make([]byte, 27+len(p.UserAgent))
binary.LittleEndian.PutUint32(buf[0:4], p.Version)
binary.LittleEndian.PutUint64(buf[4:12], p.Services)
binary.LittleEndian.PutUint32(buf[12:16], p.Timestamp)
// FIXME: byte order (little / big)?
binary.LittleEndian.PutUint16(buf[16:18], p.Port)
binary.LittleEndian.PutUint32(buf[18:22], p.Nonce)
copy(buf[22:22+len(p.UserAgent)], p.UserAgent) //
curLen := 22 + len(p.UserAgent)
binary.LittleEndian.PutUint32(buf[curLen:curLen+4], p.StartHeight)
// yikes
var b []byte
if p.Relay {
b = []byte{1}
} else {
b = []byte{0}
}
copy(buf[curLen+4:len(buf)], b)
return buf, nil
}
// convert a command (string) to a byte slice filled with 0 bytes till // convert a command (string) to a byte slice filled with 0 bytes till
// size 12. // size 12.
func cmdToByteSlice(cmd commandType) []byte { func cmdToByteSlice(cmd commandType) []byte {

View file

@ -2,30 +2,31 @@ package network
import ( import (
"bytes" "bytes"
"encoding/binary"
"reflect" "reflect"
"testing" "testing"
"github.com/anthdm/neo-go/pkg/network/payload"
) )
func TestNewMessage(t *testing.T) { // func TestNewMessage(t *testing.T) {
payload := []byte{} // payload := []byte{}
m := newMessage(ModeTestNet, cmdVersion, payload) // m := newMessage(ModeTestNet, cmdVersion, payload)
if have, want := m.Length, uint32(0); want != have { // if have, want := m.Length, uint32(0); want != have {
t.Errorf("want %d have %d", want, have) // t.Errorf("want %d have %d", want, have)
} // }
if have, want := len(m.Command), 12; want != have { // if have, want := len(m.Command), 12; want != have {
t.Errorf("want %d have %d", want, have) // t.Errorf("want %d have %d", want, have)
} // }
sum := sumSHA256(sumSHA256(payload))[:4] // sum := sumSHA256(sumSHA256(payload))[:4]
sumuint32 := binary.LittleEndian.Uint32(sum) // sumuint32 := binary.LittleEndian.Uint32(sum)
if have, want := m.Checksum, sumuint32; want != have { // if have, want := m.Checksum, sumuint32; want != have {
t.Errorf("want %d have %d", want, have) // t.Errorf("want %d have %d", want, have)
} // }
} // }
func TestMessageEncodeDecode(t *testing.T) { func TestMessageEncodeDecode(t *testing.T) {
m := newMessage(ModeTestNet, cmdVersion, []byte{}) m := newMessage(ModeTestNet, cmdVersion, nil)
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
if err := m.encode(buf); err != nil { if err := m.encode(buf); err != nil {
@ -48,34 +49,53 @@ func TestMessageEncodeDecode(t *testing.T) {
} }
} }
func TestMessageInvalidChecksum(t *testing.T) { func TestMessageEncodeDecodeWithVersion(t *testing.T) {
m := newMessage(ModeTestNet, cmdVersion, []byte{}) p := payload.NewVersion(2000, "/neo/", 0, true)
m.Checksum = 1337 m := newMessage(ModeTestNet, cmdVersion, p)
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
if err := m.encode(buf); err != nil { if err := m.encode(buf); err != nil {
t.Error(err) t.Error(err)
} }
t.Log(buf.Len())
md := &Message{} m1 := &Message{}
if err := md.decode(buf); err == nil { if err := m1.decode(buf); err != nil {
t.Error("decode should failed with checkum mismatch error")
}
}
func TestNewVersionPayload(t *testing.T) {
ua := "/neo/0.0.1/"
p := newVersionPayload(3000, ua, 0, true)
b, err := p.encode()
if err != nil {
t.Fatal(err) t.Fatal(err)
} }
p1 := m1.Payload.(*payload.Version)
pd := &Version{} t.Log(p1)
if err := pd.decode(b); err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(p, pd) {
t.Errorf("both payloads should be equal: %v != %v", p, pd)
}
} }
// func TestMessageInvalidChecksum(t *testing.T) {
// m := newMessage(ModeTestNet, cmdVersion, []byte{})
// m.Checksum = 1337
// buf := &bytes.Buffer{}
// if err := m.encode(buf); err != nil {
// t.Error(err)
// }
// md := &Message{}
// if err := md.decode(buf); err == nil {
// t.Error("decode should failed with checkum mismatch error")
// }
// }
// func TestNewVersionPayload(t *testing.T) {
// ua := "/neo/0.0.1/"
// p := newVersionPayload(3000, ua, 0, true)
// b, err := p.encode()
// if err != nil {
// t.Fatal(err)
// }
// pd := &Version{}
// if err := pd.decode(b); err != nil {
// t.Fatal(err)
// }
// if !reflect.DeepEqual(p, pd) {
// t.Errorf("both payloads should be equal: %v != %v", p, pd)
// }
// }

View file

@ -0,0 +1,37 @@
package payload
import (
"io"
"net"
"unsafe"
)
// AddrWithTime payload
type AddrWithTime struct {
Timestamp uint32
Services uint64
Addr net.Addr
}
func (p *AddrWithTime) Size() uint32 {
return 4 + 8 + uint32(unsafe.Sizeof(p.Addr))
}
func (p *AddrWithTime) Encode(r io.Reader) error {
return nil
}
func (p *AddrWithTime) Decode(w io.Writer) error {
return nil
}
// AddressList is a slice of AddrWithTime.
type AddressList []*AddrWithTime
func (p AddressList) Encode(r io.Reader) error {
return nil
}
func (p AddressList) Decode(w io.Writer) error {
return nil
}

View file

@ -0,0 +1,8 @@
package payload
import (
"testing"
)
func TestNewAddrWithTime(t *testing.T) {
}

View file

@ -0,0 +1,19 @@
package payload
import "io"
// Nothing is a safe non payload.
var Nothing = nothing{}
// Payloader ..
type Payloader interface {
Encode(io.Writer) error
Decode(io.Reader) error
Size() uint32
}
type nothing struct{}
func (p nothing) Encode(w io.Writer) error { return nil }
func (p nothing) Decode(R io.Reader) error { return nil }
func (p nothing) Size() uint32 { return 0 }

View file

@ -0,0 +1,104 @@
package payload
import (
"bytes"
"encoding/binary"
"io"
)
const minVersionSize = 27
// Version payload.
type Version struct {
// currently the version of the protocol is 0
Version uint32
// currently 1
Services uint64
// timestamp
Timestamp uint32
// port this server is listening on
Port uint16
// it's used to distinguish the node from public IP
Nonce uint32
// client id currently 6 bytes \v/NEO:2.6.0/
UserAgent []byte
// Height of the block chain
StartHeight uint32
// Whether to receive and forward
Relay bool
}
// NewVersion returns a pointer to a Version payload.
func NewVersion(p uint16, ua string, h uint32, r bool) *Version {
return &Version{
Version: 0,
Services: 1,
Timestamp: 12345,
Port: p,
Nonce: 19110,
UserAgent: []byte(ua),
StartHeight: 0,
Relay: r,
}
}
// Size ..
func (p *Version) Size() uint32 {
n := minVersionSize + len(p.UserAgent)
return uint32(n)
}
// Decode ..
func (p *Version) Decode(r io.Reader) error {
buf := new(bytes.Buffer)
if _, err := buf.ReadFrom(r); err != nil {
return err
}
b := buf.Bytes()
// 27 bytes for the fixed size fields + the length of the user agent
// which is kinda variable, according to the docs.
lenUA := len(b) - minVersionSize
p.Version = binary.LittleEndian.Uint32(b[0:4])
p.Services = binary.LittleEndian.Uint64(b[4:12])
p.Timestamp = binary.LittleEndian.Uint32(b[12:16])
// FIXME: port's byteorder should be big endian according to the docs.
// but when connecting to the privnet docker image it's little endian.
p.Port = binary.LittleEndian.Uint16(b[16:18])
p.Nonce = binary.LittleEndian.Uint32(b[18:22])
p.UserAgent = b[22 : 22+lenUA]
curlen := 22 + lenUA
p.StartHeight = binary.LittleEndian.Uint32(b[curlen : curlen+4])
p.Relay = b[len(b)-1 : len(b)][0] == 1
return nil
}
// Encode ..
func (p *Version) Encode(w io.Writer) error {
buf := make([]byte, p.Size())
binary.LittleEndian.PutUint32(buf[0:4], p.Version)
binary.LittleEndian.PutUint64(buf[4:12], p.Services)
binary.LittleEndian.PutUint32(buf[12:16], p.Timestamp)
// FIXME: byte order (little / big)?
binary.LittleEndian.PutUint16(buf[16:18], p.Port)
binary.LittleEndian.PutUint32(buf[18:22], p.Nonce)
copy(buf[22:22+len(p.UserAgent)], p.UserAgent) //
curLen := 22 + len(p.UserAgent)
binary.LittleEndian.PutUint32(buf[curLen:curLen+4], p.StartHeight)
// yikes
var b []byte
if p.Relay {
b = []byte{1}
} else {
b = []byte{0}
}
copy(buf[curLen+4:len(buf)], b)
_, err := w.Write(buf)
return err
}

View file

@ -0,0 +1,21 @@
package payload
import (
"bytes"
"reflect"
"testing"
)
func TestVersionEncodeDecode(t *testing.T) {
p := NewVersion(3000, "/NEO/", 0, true)
buf := new(bytes.Buffer)
p.Encode(buf)
pd := &Version{}
pd.Decode(buf)
if !reflect.DeepEqual(p, pd) {
t.Fatalf("expect %v to be equal to %v", p, pd)
}
}

View file

@ -7,6 +7,8 @@ import (
"net" "net"
"os" "os"
"strconv" "strconv"
"github.com/anthdm/neo-go/pkg/network/payload"
) )
const ( const (
@ -165,11 +167,11 @@ func (s *Server) processMessage(msg *Message, peer *Peer) error {
switch msg.commandType() { switch msg.commandType() {
case cmdVersion: case cmdVersion:
v, err := msg.decodePayload() // v, err := msg.decodePayload()
if err != nil { // if err != nil {
return err // return err
} // }
return s.handleVersionCmd(v.(*Version), peer) // return s.handleVersionCmd(v.(*Version), peer)
case cmdVerack: case cmdVerack:
case cmdGetAddr: case cmdGetAddr:
return s.handleGetAddrCmd(msg, peer) return s.handleGetAddrCmd(msg, peer)
@ -192,27 +194,18 @@ func (s *Server) processMessage(msg *Message, peer *Peer) error {
// No further communication should been made before both sides has received // No further communication should been made before both sides has received
// the version of eachother. // the version of eachother.
func (s *Server) handlePeerConnected() (*Message, error) { func (s *Server) handlePeerConnected() (*Message, error) {
payload := newVersionPayload(s.port, s.userAgent, 0, s.relay) payload := payload.NewVersion(s.port, s.userAgent, 0, s.relay)
b, err := payload.encode() msg := newMessage(s.net, cmdVersion, payload)
if err != nil {
return nil, err
}
msg := newMessage(s.net, cmdVersion, b)
return msg, nil return msg, nil
} }
// Version declares the server's version. // Version declares the server's version.
func (s *Server) handleVersionCmd(v *Version, peer *Peer) error { func (s *Server) handleVersionCmd(v *payload.Version, peer *Peer) error {
// TODO: check version and verify to trust that node. // TODO: check version and verify to trust that node.
payload := newVersionPayload(s.port, s.userAgent, 0, s.relay) payload := payload.NewVersion(s.port, s.userAgent, 0, s.relay)
b, err := payload.encode()
if err != nil {
return err
}
// we respond with our version. // we respond with our version.
versionMsg := newMessage(s.net, cmdVersion, b) versionMsg := newMessage(s.net, cmdVersion, payload)
peer.send <- versionMsg peer.send <- versionMsg
// we respond with a verack, we successfully received peer's version // we respond with a verack, we successfully received peer's version