uint256 + inventoryType

This commit is contained in:
anthdm 2018-01-28 08:03:18 +01:00
parent 4f6090cebf
commit f28d8f9ab6
4 changed files with 89 additions and 112 deletions

View file

@ -5,7 +5,6 @@ import (
"crypto/sha256" "crypto/sha256"
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt"
"io" "io"
"github.com/anthdm/neo-go/pkg/network/payload" "github.com/anthdm/neo-go/pkg/network/payload"
@ -81,15 +80,25 @@ const (
) )
func newMessage(magic NetMode, cmd commandType, p payload.Payloader) *Message { func newMessage(magic NetMode, cmd commandType, p payload.Payloader) *Message {
var size uint32 var (
size uint32
checksum []byte
)
if p != nil { if p != nil {
size = p.Size() size = p.Size()
b, _ := p.MarshalBinary()
checksum = sumSHA256(sumSHA256(b))
} else {
checksum = sumSHA256(sumSHA256([]byte{}))
} }
return &Message{ return &Message{
Magic: magic, Magic: magic,
Command: cmdToByteSlice(cmd), Command: cmdToByteSlice(cmd),
Length: size, Length: size,
Payload: p, Payload: p,
Checksum: binary.LittleEndian.Uint32(checksum[:4]),
} }
} }
@ -137,40 +146,39 @@ func (m *Message) decode(r io.Reader) error {
m.Length = binary.LittleEndian.Uint32(buf[16:20]) m.Length = binary.LittleEndian.Uint32(buf[16:20])
m.Checksum = binary.LittleEndian.Uint32(buf[20:24]) m.Checksum = binary.LittleEndian.Uint32(buf[20:24])
// if their is no payload. // return if their is no payload.
if m.Length == 0 || !needPayloadDecode(m.commandType()) { if m.Length == 0 {
return nil return nil
} }
return m.decodePayload(r) return m.unmarshalPayload(r)
} }
func (m *Message) decodePayload(r io.Reader) error { func (m *Message) unmarshalPayload(r io.Reader) error {
// write to a buffer what we read to calculate the checksum. pbuf := make([]byte, m.Length)
buffer := new(bytes.Buffer) if _, err := r.Read(pbuf); err != nil {
tr := io.TeeReader(r, buffer)
var p payload.Payloader
switch m.commandType() {
case cmdVersion:
p = &payload.Version{}
if err := p.Decode(tr); err != nil {
return err return err
} }
case cmdInv:
p = payload.Inventories{}
if err := p.Decode(tr); err != nil {
return err
}
default:
return fmt.Errorf("unknown command to decode: %s", m.commandType())
}
// Compare the checksum of the payload. // Compare the checksum of the payload.
if !compareChecksum(m.Checksum, buffer.Bytes()) { if !compareChecksum(m.Checksum, pbuf) {
return errors.New("checksum mismatch error") return errors.New("checksum mismatch error")
} }
var p payload.Payloader
switch m.commandType() {
case cmdVersion:
p = &payload.Version{}
if err := p.UnmarshalBinary(pbuf); err != nil {
return err
}
case cmdInv:
p = &payload.Inventory{}
if err := p.UnmarshalBinary(pbuf); err != nil {
return err
}
}
m.Payload = p m.Payload = p
return nil return nil
@ -178,51 +186,25 @@ func (m *Message) decodePayload(r io.Reader) error {
// encode a Message to any given io.Writer. // encode a Message to any given io.Writer.
func (m *Message) encode(w io.Writer) error { func (m *Message) encode(w io.Writer) error {
buf := make([]byte, minMessageSize) buf := make([]byte, minMessageSize+m.Length)
pbuf := new(bytes.Buffer)
// if there is a payload fill its allocated buffer.
var checksum []byte
if m.Payload != nil {
if err := m.Payload.Encode(pbuf); err != nil {
return err
}
checksum = sumSHA256(sumSHA256(pbuf.Bytes()))[:4]
} else {
checksum = sumSHA256(sumSHA256([]byte{}))[:4]
}
m.Checksum = binary.LittleEndian.Uint32(checksum)
// fill the message buffer
binary.LittleEndian.PutUint32(buf[0:4], uint32(m.Magic)) binary.LittleEndian.PutUint32(buf[0:4], uint32(m.Magic))
copy(buf[4:16], m.Command) copy(buf[4:16], m.Command)
binary.LittleEndian.PutUint32(buf[16:20], m.Length) binary.LittleEndian.PutUint32(buf[16:20], m.Length)
binary.LittleEndian.PutUint32(buf[20:24], m.Checksum) binary.LittleEndian.PutUint32(buf[20:24], m.Checksum)
// write the message if m.Payload != nil {
n, err := w.Write(buf) payload, err := m.Payload.MarshalBinary()
if err != nil { if err != nil {
return err return err
} }
copy(buf[minMessageSize:minMessageSize+m.Length], payload)
// we need to have at least writen exactly minMessageSize bytes.
if n != minMessageSize {
return errors.New("long/short read error when encoding message")
} }
// write the payload if there was any if _, err := w.Write(buf); err != nil {
if pbuf.Len() > 0 {
n, err = w.Write(pbuf.Bytes())
if err != nil {
return err return err
} }
if uint32(n) != m.Payload.Size() {
return errors.New("long/short read error when encoding payload")
}
}
return nil return nil
} }
@ -243,10 +225,6 @@ func cmdToByteSlice(cmd commandType) []byte {
return b return b
} }
func needPayloadDecode(cmd commandType) bool {
return cmd != cmdVerack && cmd != cmdGetAddr
}
func sumSHA256(b []byte) []byte { func sumSHA256(b []byte) []byte {
h := sha256.New() h := sha256.New()
h.Write(b) h.Write(b)

View file

@ -1,11 +1,13 @@
package payload package payload
import "io" import (
"encoding"
)
// Payloader is anything that can be binary encoded and decoded. // Payloader is anything that can be binary marshaled and unmarshaled.
// Every payload used in messages need to satisfy the Payloader interface. // Every payload embbedded in messages need to satisfy the Payloader interface.
type Payloader interface { type Payloader interface {
Encode(io.Writer) error encoding.BinaryMarshaler
Decode(io.Reader) error encoding.BinaryUnmarshaler
Size() uint32 Size() uint32
} }

View file

@ -2,7 +2,6 @@ package payload
import ( import (
"encoding/binary" "encoding/binary"
"io"
) )
const ( const (
@ -31,36 +30,21 @@ type Version struct {
} }
// NewVersion returns a pointer to a Version payload. // NewVersion returns a pointer to a Version payload.
func NewVersion(p uint16, ua string, h uint32, r bool) *Version { func NewVersion(id uint32, p uint16, ua string, h uint32, r bool) *Version {
return &Version{ return &Version{
Version: 0, Version: 0,
Services: 1, Services: 1,
Timestamp: 12345, Timestamp: 12345,
Port: p, Port: p,
Nonce: 19110, Nonce: id,
UserAgent: []byte(ua), UserAgent: []byte(ua),
StartHeight: 0, StartHeight: 0,
Relay: r, Relay: r,
} }
} }
// Size implements the Payloader interface. // UnmarshalBinary implements the Payloader interface.
func (p *Version) Size() uint32 { func (p *Version) UnmarshalBinary(b []byte) error {
n := minVersionSize
return uint32(n)
}
// Decode implements the Payloader interface.
func (p *Version) Decode(r io.Reader) error {
b := make([]byte, minVersionSize)
if _, err := r.Read(b); err != nil {
return err
}
// 27 bytes for the fixed size fields + the length of the user agent
// which is kinda variable, according to the docs.
lenUA := len(b) - minVersionSize
p.Version = binary.LittleEndian.Uint32(b[0:4]) p.Version = binary.LittleEndian.Uint32(b[0:4])
p.Services = binary.LittleEndian.Uint64(b[4:12]) p.Services = binary.LittleEndian.Uint64(b[4:12])
p.Timestamp = binary.LittleEndian.Uint32(b[12:16]) p.Timestamp = binary.LittleEndian.Uint32(b[12:16])
@ -76,30 +60,33 @@ func (p *Version) Decode(r io.Reader) error {
return nil return nil
} }
// Encode implements the Payloader interface. // MarshalBinary implements the Payloader interface.
func (p *Version) Encode(w io.Writer) error { func (p *Version) MarshalBinary() ([]byte, error) {
buf := make([]byte, p.Size()) b := make([]byte, p.Size())
binary.LittleEndian.PutUint32(buf[0:4], p.Version) binary.LittleEndian.PutUint32(b[0:4], p.Version)
binary.LittleEndian.PutUint64(buf[4:12], p.Services) binary.LittleEndian.PutUint64(b[4:12], p.Services)
binary.LittleEndian.PutUint32(buf[12:16], p.Timestamp) binary.LittleEndian.PutUint32(b[12:16], p.Timestamp)
// FIXME: byte order (little / big)? // FIXME: byte order (little / big)?
binary.LittleEndian.PutUint16(buf[16:18], p.Port) binary.LittleEndian.PutUint16(b[16:18], p.Port)
binary.LittleEndian.PutUint32(buf[18:22], p.Nonce) binary.LittleEndian.PutUint32(b[18:22], p.Nonce)
copy(buf[22:22+len(p.UserAgent)], p.UserAgent) // copy(b[22:22+len(p.UserAgent)], p.UserAgent) //
curLen := 22 + len(p.UserAgent) curLen := 22 + len(p.UserAgent)
binary.LittleEndian.PutUint32(buf[curLen:curLen+4], p.StartHeight) binary.LittleEndian.PutUint32(b[curLen:curLen+4], p.StartHeight)
// yikes // yikes
var b []byte var bln []byte
if p.Relay { if p.Relay {
b = []byte{1} bln = []byte{1}
} else { } else {
b = []byte{0} bln = []byte{0}
} }
copy(buf[curLen+4:len(buf)], b) copy(b[curLen+4:len(b)], bln)
_, err := w.Write(buf) return b, nil
return err }
func (p *Version) Size() uint32 {
return uint32(minVersionSize + len(p.UserAgent))
} }

View file

@ -9,6 +9,7 @@ import (
"strconv" "strconv"
"github.com/anthdm/neo-go/pkg/network/payload" "github.com/anthdm/neo-go/pkg/network/payload"
"github.com/anthdm/neo-go/pkg/util"
) )
const ( const (
@ -33,6 +34,9 @@ type messageTuple struct {
type Server struct { type Server struct {
logger *log.Logger logger *log.Logger
// id of the server
id uint32
// the port the TCP listener is listening on. // the port the TCP listener is listening on.
port uint16 port uint16
@ -72,6 +76,7 @@ func NewServer(net NetMode) *Server {
s := &Server{ s := &Server{
// It is important to have this user agent correct. Otherwise we will get // It is important to have this user agent correct. Otherwise we will get
// disconnected. // disconnected.
id: util.RandUint32(1111111, 9999999),
userAgent: fmt.Sprintf("\v/NEO:%s/", version), userAgent: fmt.Sprintf("\v/NEO:%s/", version),
logger: logger, logger: logger,
peers: make(map[*Peer]bool), peers: make(map[*Peer]bool),
@ -95,8 +100,10 @@ func (s *Server) Start(port string, seeds []string) {
s.port = uint16(p) s.port = uint16(p)
fmt.Println(logo()) fmt.Println(logo())
s.logger.Printf("running %s on %s - TCP %d - relay: %v", fmt.Println(string(s.userAgent))
s.userAgent, s.net, int(s.port), s.relay) fmt.Println("")
s.logger.Printf("NET: %s - TCP: %d - RELAY: %v - ID: %d",
s.net, int(s.port), s.relay, s.id)
go listenTCP(s, port) go listenTCP(s, port)
@ -163,7 +170,10 @@ func (s *Server) loop() {
// TODO: unregister peers on error. // TODO: unregister peers on error.
// processMessage processes the received message from a remote node. // processMessage processes the received message from a remote node.
func (s *Server) processMessage(msg *Message, peer *Peer) error { func (s *Server) processMessage(msg *Message, peer *Peer) error {
rpcLogger.Printf("IN :: %+v", string(msg.Command)) rpcLogger.Printf("IN :: %s", msg.commandType())
if msg.Length > 0 {
rpcLogger.Printf("IN :: %+v", msg.Payload)
}
switch msg.commandType() { switch msg.commandType() {
case cmdVersion: case cmdVersion:
@ -190,7 +200,7 @@ func (s *Server) processMessage(msg *Message, peer *Peer) error {
// No further communication should been made before both sides has received // No further communication should been made before both sides has received
// the version of eachother. // the version of eachother.
func (s *Server) handlePeerConnected() (*Message, error) { func (s *Server) handlePeerConnected() (*Message, error) {
payload := payload.NewVersion(s.port, s.userAgent, 0, s.relay) payload := payload.NewVersion(s.id, s.port, s.userAgent, 0, s.relay)
msg := newMessage(s.net, cmdVersion, payload) msg := newMessage(s.net, cmdVersion, payload)
return msg, nil return msg, nil
} }
@ -199,7 +209,7 @@ func (s *Server) handlePeerConnected() (*Message, error) {
func (s *Server) handleVersionCmd(v *payload.Version, peer *Peer) error { func (s *Server) handleVersionCmd(v *payload.Version, peer *Peer) error {
// TODO: check version and verify to trust that node. // TODO: check version and verify to trust that node.
payload := payload.NewVersion(s.port, s.userAgent, 0, s.relay) payload := payload.NewVersion(s.id, s.port, s.userAgent, 0, s.relay)
// we respond with our version. // we respond with our version.
versionMsg := newMessage(s.net, cmdVersion, payload) versionMsg := newMessage(s.net, cmdVersion, payload)
peer.send <- versionMsg peer.send <- versionMsg