forked from TrueCloudLab/neoneo-go
uint256 + inventoryType
This commit is contained in:
parent
4f6090cebf
commit
f28d8f9ab6
4 changed files with 89 additions and 112 deletions
|
@ -5,7 +5,6 @@ import (
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"io"
|
"io"
|
||||||
|
|
||||||
"github.com/anthdm/neo-go/pkg/network/payload"
|
"github.com/anthdm/neo-go/pkg/network/payload"
|
||||||
|
@ -81,15 +80,25 @@ const (
|
||||||
)
|
)
|
||||||
|
|
||||||
func newMessage(magic NetMode, cmd commandType, p payload.Payloader) *Message {
|
func newMessage(magic NetMode, cmd commandType, p payload.Payloader) *Message {
|
||||||
var size uint32
|
var (
|
||||||
|
size uint32
|
||||||
|
checksum []byte
|
||||||
|
)
|
||||||
|
|
||||||
if p != nil {
|
if p != nil {
|
||||||
size = p.Size()
|
size = p.Size()
|
||||||
|
b, _ := p.MarshalBinary()
|
||||||
|
checksum = sumSHA256(sumSHA256(b))
|
||||||
|
} else {
|
||||||
|
checksum = sumSHA256(sumSHA256([]byte{}))
|
||||||
}
|
}
|
||||||
|
|
||||||
return &Message{
|
return &Message{
|
||||||
Magic: magic,
|
Magic: magic,
|
||||||
Command: cmdToByteSlice(cmd),
|
Command: cmdToByteSlice(cmd),
|
||||||
Length: size,
|
Length: size,
|
||||||
Payload: p,
|
Payload: p,
|
||||||
|
Checksum: binary.LittleEndian.Uint32(checksum[:4]),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -137,40 +146,39 @@ func (m *Message) decode(r io.Reader) error {
|
||||||
m.Length = binary.LittleEndian.Uint32(buf[16:20])
|
m.Length = binary.LittleEndian.Uint32(buf[16:20])
|
||||||
m.Checksum = binary.LittleEndian.Uint32(buf[20:24])
|
m.Checksum = binary.LittleEndian.Uint32(buf[20:24])
|
||||||
|
|
||||||
// if their is no payload.
|
// return if their is no payload.
|
||||||
if m.Length == 0 || !needPayloadDecode(m.commandType()) {
|
if m.Length == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return m.decodePayload(r)
|
return m.unmarshalPayload(r)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Message) decodePayload(r io.Reader) error {
|
func (m *Message) unmarshalPayload(r io.Reader) error {
|
||||||
// write to a buffer what we read to calculate the checksum.
|
pbuf := make([]byte, m.Length)
|
||||||
buffer := new(bytes.Buffer)
|
if _, err := r.Read(pbuf); err != nil {
|
||||||
tr := io.TeeReader(r, buffer)
|
return err
|
||||||
var p payload.Payloader
|
|
||||||
|
|
||||||
switch m.commandType() {
|
|
||||||
case cmdVersion:
|
|
||||||
p = &payload.Version{}
|
|
||||||
if err := p.Decode(tr); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
case cmdInv:
|
|
||||||
p = payload.Inventories{}
|
|
||||||
if err := p.Decode(tr); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("unknown command to decode: %s", m.commandType())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Compare the checksum of the payload.
|
// Compare the checksum of the payload.
|
||||||
if !compareChecksum(m.Checksum, buffer.Bytes()) {
|
if !compareChecksum(m.Checksum, pbuf) {
|
||||||
return errors.New("checksum mismatch error")
|
return errors.New("checksum mismatch error")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var p payload.Payloader
|
||||||
|
switch m.commandType() {
|
||||||
|
case cmdVersion:
|
||||||
|
p = &payload.Version{}
|
||||||
|
if err := p.UnmarshalBinary(pbuf); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
case cmdInv:
|
||||||
|
p = &payload.Inventory{}
|
||||||
|
if err := p.UnmarshalBinary(pbuf); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
m.Payload = p
|
m.Payload = p
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
@ -178,49 +186,23 @@ func (m *Message) decodePayload(r io.Reader) error {
|
||||||
|
|
||||||
// encode a Message to any given io.Writer.
|
// encode a Message to any given io.Writer.
|
||||||
func (m *Message) encode(w io.Writer) error {
|
func (m *Message) encode(w io.Writer) error {
|
||||||
buf := make([]byte, minMessageSize)
|
buf := make([]byte, minMessageSize+m.Length)
|
||||||
pbuf := new(bytes.Buffer)
|
|
||||||
|
|
||||||
// if there is a payload fill its allocated buffer.
|
|
||||||
var checksum []byte
|
|
||||||
if m.Payload != nil {
|
|
||||||
if err := m.Payload.Encode(pbuf); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
checksum = sumSHA256(sumSHA256(pbuf.Bytes()))[:4]
|
|
||||||
} else {
|
|
||||||
checksum = sumSHA256(sumSHA256([]byte{}))[:4]
|
|
||||||
}
|
|
||||||
|
|
||||||
m.Checksum = binary.LittleEndian.Uint32(checksum)
|
|
||||||
|
|
||||||
// fill the message buffer
|
|
||||||
binary.LittleEndian.PutUint32(buf[0:4], uint32(m.Magic))
|
binary.LittleEndian.PutUint32(buf[0:4], uint32(m.Magic))
|
||||||
copy(buf[4:16], m.Command)
|
copy(buf[4:16], m.Command)
|
||||||
binary.LittleEndian.PutUint32(buf[16:20], m.Length)
|
binary.LittleEndian.PutUint32(buf[16:20], m.Length)
|
||||||
binary.LittleEndian.PutUint32(buf[20:24], m.Checksum)
|
binary.LittleEndian.PutUint32(buf[20:24], m.Checksum)
|
||||||
|
|
||||||
// write the message
|
if m.Payload != nil {
|
||||||
n, err := w.Write(buf)
|
payload, err := m.Payload.MarshalBinary()
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// we need to have at least writen exactly minMessageSize bytes.
|
|
||||||
if n != minMessageSize {
|
|
||||||
return errors.New("long/short read error when encoding message")
|
|
||||||
}
|
|
||||||
|
|
||||||
// write the payload if there was any
|
|
||||||
if pbuf.Len() > 0 {
|
|
||||||
n, err = w.Write(pbuf.Bytes())
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
copy(buf[minMessageSize:minMessageSize+m.Length], payload)
|
||||||
|
}
|
||||||
|
|
||||||
if uint32(n) != m.Payload.Size() {
|
if _, err := w.Write(buf); err != nil {
|
||||||
return errors.New("long/short read error when encoding payload")
|
return err
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
@ -243,10 +225,6 @@ func cmdToByteSlice(cmd commandType) []byte {
|
||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
func needPayloadDecode(cmd commandType) bool {
|
|
||||||
return cmd != cmdVerack && cmd != cmdGetAddr
|
|
||||||
}
|
|
||||||
|
|
||||||
func sumSHA256(b []byte) []byte {
|
func sumSHA256(b []byte) []byte {
|
||||||
h := sha256.New()
|
h := sha256.New()
|
||||||
h.Write(b)
|
h.Write(b)
|
||||||
|
|
|
@ -1,11 +1,13 @@
|
||||||
package payload
|
package payload
|
||||||
|
|
||||||
import "io"
|
import (
|
||||||
|
"encoding"
|
||||||
|
)
|
||||||
|
|
||||||
// Payloader is anything that can be binary encoded and decoded.
|
// Payloader is anything that can be binary marshaled and unmarshaled.
|
||||||
// Every payload used in messages need to satisfy the Payloader interface.
|
// Every payload embbedded in messages need to satisfy the Payloader interface.
|
||||||
type Payloader interface {
|
type Payloader interface {
|
||||||
Encode(io.Writer) error
|
encoding.BinaryMarshaler
|
||||||
Decode(io.Reader) error
|
encoding.BinaryUnmarshaler
|
||||||
Size() uint32
|
Size() uint32
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,7 +2,6 @@ package payload
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"io"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -31,36 +30,21 @@ type Version struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewVersion returns a pointer to a Version payload.
|
// NewVersion returns a pointer to a Version payload.
|
||||||
func NewVersion(p uint16, ua string, h uint32, r bool) *Version {
|
func NewVersion(id uint32, p uint16, ua string, h uint32, r bool) *Version {
|
||||||
return &Version{
|
return &Version{
|
||||||
Version: 0,
|
Version: 0,
|
||||||
Services: 1,
|
Services: 1,
|
||||||
Timestamp: 12345,
|
Timestamp: 12345,
|
||||||
Port: p,
|
Port: p,
|
||||||
Nonce: 19110,
|
Nonce: id,
|
||||||
UserAgent: []byte(ua),
|
UserAgent: []byte(ua),
|
||||||
StartHeight: 0,
|
StartHeight: 0,
|
||||||
Relay: r,
|
Relay: r,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Size implements the Payloader interface.
|
// UnmarshalBinary implements the Payloader interface.
|
||||||
func (p *Version) Size() uint32 {
|
func (p *Version) UnmarshalBinary(b []byte) error {
|
||||||
n := minVersionSize
|
|
||||||
return uint32(n)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Decode implements the Payloader interface.
|
|
||||||
func (p *Version) Decode(r io.Reader) error {
|
|
||||||
b := make([]byte, minVersionSize)
|
|
||||||
if _, err := r.Read(b); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// 27 bytes for the fixed size fields + the length of the user agent
|
|
||||||
// which is kinda variable, according to the docs.
|
|
||||||
lenUA := len(b) - minVersionSize
|
|
||||||
|
|
||||||
p.Version = binary.LittleEndian.Uint32(b[0:4])
|
p.Version = binary.LittleEndian.Uint32(b[0:4])
|
||||||
p.Services = binary.LittleEndian.Uint64(b[4:12])
|
p.Services = binary.LittleEndian.Uint64(b[4:12])
|
||||||
p.Timestamp = binary.LittleEndian.Uint32(b[12:16])
|
p.Timestamp = binary.LittleEndian.Uint32(b[12:16])
|
||||||
|
@ -76,30 +60,33 @@ func (p *Version) Decode(r io.Reader) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Encode implements the Payloader interface.
|
// MarshalBinary implements the Payloader interface.
|
||||||
func (p *Version) Encode(w io.Writer) error {
|
func (p *Version) MarshalBinary() ([]byte, error) {
|
||||||
buf := make([]byte, p.Size())
|
b := make([]byte, p.Size())
|
||||||
|
|
||||||
binary.LittleEndian.PutUint32(buf[0:4], p.Version)
|
binary.LittleEndian.PutUint32(b[0:4], p.Version)
|
||||||
binary.LittleEndian.PutUint64(buf[4:12], p.Services)
|
binary.LittleEndian.PutUint64(b[4:12], p.Services)
|
||||||
binary.LittleEndian.PutUint32(buf[12:16], p.Timestamp)
|
binary.LittleEndian.PutUint32(b[12:16], p.Timestamp)
|
||||||
// FIXME: byte order (little / big)?
|
// FIXME: byte order (little / big)?
|
||||||
binary.LittleEndian.PutUint16(buf[16:18], p.Port)
|
binary.LittleEndian.PutUint16(b[16:18], p.Port)
|
||||||
binary.LittleEndian.PutUint32(buf[18:22], p.Nonce)
|
binary.LittleEndian.PutUint32(b[18:22], p.Nonce)
|
||||||
copy(buf[22:22+len(p.UserAgent)], p.UserAgent) //
|
copy(b[22:22+len(p.UserAgent)], p.UserAgent) //
|
||||||
curLen := 22 + len(p.UserAgent)
|
curLen := 22 + len(p.UserAgent)
|
||||||
binary.LittleEndian.PutUint32(buf[curLen:curLen+4], p.StartHeight)
|
binary.LittleEndian.PutUint32(b[curLen:curLen+4], p.StartHeight)
|
||||||
|
|
||||||
// yikes
|
// yikes
|
||||||
var b []byte
|
var bln []byte
|
||||||
if p.Relay {
|
if p.Relay {
|
||||||
b = []byte{1}
|
bln = []byte{1}
|
||||||
} else {
|
} else {
|
||||||
b = []byte{0}
|
bln = []byte{0}
|
||||||
}
|
}
|
||||||
|
|
||||||
copy(buf[curLen+4:len(buf)], b)
|
copy(b[curLen+4:len(b)], bln)
|
||||||
|
|
||||||
_, err := w.Write(buf)
|
return b, nil
|
||||||
return err
|
}
|
||||||
|
|
||||||
|
func (p *Version) Size() uint32 {
|
||||||
|
return uint32(minVersionSize + len(p.UserAgent))
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,6 +9,7 @@ import (
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
"github.com/anthdm/neo-go/pkg/network/payload"
|
"github.com/anthdm/neo-go/pkg/network/payload"
|
||||||
|
"github.com/anthdm/neo-go/pkg/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -33,6 +34,9 @@ type messageTuple struct {
|
||||||
type Server struct {
|
type Server struct {
|
||||||
logger *log.Logger
|
logger *log.Logger
|
||||||
|
|
||||||
|
// id of the server
|
||||||
|
id uint32
|
||||||
|
|
||||||
// the port the TCP listener is listening on.
|
// the port the TCP listener is listening on.
|
||||||
port uint16
|
port uint16
|
||||||
|
|
||||||
|
@ -72,6 +76,7 @@ func NewServer(net NetMode) *Server {
|
||||||
s := &Server{
|
s := &Server{
|
||||||
// It is important to have this user agent correct. Otherwise we will get
|
// It is important to have this user agent correct. Otherwise we will get
|
||||||
// disconnected.
|
// disconnected.
|
||||||
|
id: util.RandUint32(1111111, 9999999),
|
||||||
userAgent: fmt.Sprintf("\v/NEO:%s/", version),
|
userAgent: fmt.Sprintf("\v/NEO:%s/", version),
|
||||||
logger: logger,
|
logger: logger,
|
||||||
peers: make(map[*Peer]bool),
|
peers: make(map[*Peer]bool),
|
||||||
|
@ -95,8 +100,10 @@ func (s *Server) Start(port string, seeds []string) {
|
||||||
s.port = uint16(p)
|
s.port = uint16(p)
|
||||||
|
|
||||||
fmt.Println(logo())
|
fmt.Println(logo())
|
||||||
s.logger.Printf("running %s on %s - TCP %d - relay: %v",
|
fmt.Println(string(s.userAgent))
|
||||||
s.userAgent, s.net, int(s.port), s.relay)
|
fmt.Println("")
|
||||||
|
s.logger.Printf("NET: %s - TCP: %d - RELAY: %v - ID: %d",
|
||||||
|
s.net, int(s.port), s.relay, s.id)
|
||||||
|
|
||||||
go listenTCP(s, port)
|
go listenTCP(s, port)
|
||||||
|
|
||||||
|
@ -163,7 +170,10 @@ func (s *Server) loop() {
|
||||||
// TODO: unregister peers on error.
|
// TODO: unregister peers on error.
|
||||||
// processMessage processes the received message from a remote node.
|
// processMessage processes the received message from a remote node.
|
||||||
func (s *Server) processMessage(msg *Message, peer *Peer) error {
|
func (s *Server) processMessage(msg *Message, peer *Peer) error {
|
||||||
rpcLogger.Printf("IN :: %+v", string(msg.Command))
|
rpcLogger.Printf("IN :: %s", msg.commandType())
|
||||||
|
if msg.Length > 0 {
|
||||||
|
rpcLogger.Printf("IN :: %+v", msg.Payload)
|
||||||
|
}
|
||||||
|
|
||||||
switch msg.commandType() {
|
switch msg.commandType() {
|
||||||
case cmdVersion:
|
case cmdVersion:
|
||||||
|
@ -190,7 +200,7 @@ func (s *Server) processMessage(msg *Message, peer *Peer) error {
|
||||||
// No further communication should been made before both sides has received
|
// No further communication should been made before both sides has received
|
||||||
// the version of eachother.
|
// the version of eachother.
|
||||||
func (s *Server) handlePeerConnected() (*Message, error) {
|
func (s *Server) handlePeerConnected() (*Message, error) {
|
||||||
payload := payload.NewVersion(s.port, s.userAgent, 0, s.relay)
|
payload := payload.NewVersion(s.id, s.port, s.userAgent, 0, s.relay)
|
||||||
msg := newMessage(s.net, cmdVersion, payload)
|
msg := newMessage(s.net, cmdVersion, payload)
|
||||||
return msg, nil
|
return msg, nil
|
||||||
}
|
}
|
||||||
|
@ -199,7 +209,7 @@ func (s *Server) handlePeerConnected() (*Message, error) {
|
||||||
func (s *Server) handleVersionCmd(v *payload.Version, peer *Peer) error {
|
func (s *Server) handleVersionCmd(v *payload.Version, peer *Peer) error {
|
||||||
// TODO: check version and verify to trust that node.
|
// TODO: check version and verify to trust that node.
|
||||||
|
|
||||||
payload := payload.NewVersion(s.port, s.userAgent, 0, s.relay)
|
payload := payload.NewVersion(s.id, s.port, s.userAgent, 0, s.relay)
|
||||||
// we respond with our version.
|
// we respond with our version.
|
||||||
versionMsg := newMessage(s.net, cmdVersion, payload)
|
versionMsg := newMessage(s.net, cmdVersion, payload)
|
||||||
peer.send <- versionMsg
|
peer.send <- versionMsg
|
||||||
|
|
Loading…
Reference in a new issue