wip implement inv command
This commit is contained in:
parent
754a473488
commit
d7826a4d43
7 changed files with 51 additions and 88 deletions
|
@ -1,25 +0,0 @@
|
||||||
package network
|
|
||||||
|
|
||||||
import "net"
|
|
||||||
|
|
||||||
// AddrWithTimestamp payload.
|
|
||||||
type AddrWithTimestamp struct {
|
|
||||||
t uint32
|
|
||||||
services uint64
|
|
||||||
endpoint net.Addr
|
|
||||||
}
|
|
||||||
|
|
||||||
func newAddrWithTimestampFromPeer(p *Peer) AddrWithTimestamp {
|
|
||||||
return AddrWithTimestamp{
|
|
||||||
t: 1223345,
|
|
||||||
services: 1,
|
|
||||||
endpoint: p.conn.RemoteAddr(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddrPayload container a list of known peer addresses.
|
|
||||||
type AddrPayload []AddrWithTimestamp
|
|
||||||
|
|
||||||
func (p AddrPayload) encode() ([]byte, error) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
|
||||||
"github.com/anthdm/neo-go/pkg/network/payload"
|
"github.com/anthdm/neo-go/pkg/network/payload"
|
||||||
|
@ -92,23 +93,6 @@ func newMessage(magic NetMode, cmd commandType, p payload.Payloader) *Message {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TeeWriter(w io.Writer, r io.Reader) io.Writer {
|
|
||||||
return &teeWriter{w, r}
|
|
||||||
}
|
|
||||||
|
|
||||||
type teeWriter struct {
|
|
||||||
w io.Writer
|
|
||||||
r io.Reader
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *teeWriter) Write(b []byte) (n int, err error) {
|
|
||||||
n, err = w.w.Write(b)
|
|
||||||
if n > 0 {
|
|
||||||
n, err = w.r.Read(b[:n])
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Converts the 12 byte command slice to a commandType.
|
// Converts the 12 byte command slice to a commandType.
|
||||||
func (m *Message) commandType() commandType {
|
func (m *Message) commandType() commandType {
|
||||||
cmd := string(bytes.TrimRight(m.Command, "\x00"))
|
cmd := string(bytes.TrimRight(m.Command, "\x00"))
|
||||||
|
@ -153,11 +137,16 @@ func (m *Message) decode(r io.Reader) error {
|
||||||
m.Length = binary.LittleEndian.Uint32(buf[16:20])
|
m.Length = binary.LittleEndian.Uint32(buf[16:20])
|
||||||
m.Checksum = binary.LittleEndian.Uint32(buf[20:24])
|
m.Checksum = binary.LittleEndian.Uint32(buf[20:24])
|
||||||
|
|
||||||
// payload is 0, so dont decode it.
|
// if their is no payload.
|
||||||
if m.Length == 0 {
|
if m.Length == 0 || !needPayloadDecode(m.commandType()) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return m.decodePayload(r)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Message) decodePayload(r io.Reader) error {
|
||||||
|
// write to a buffer what we read to calculate the checksum.
|
||||||
buffer := new(bytes.Buffer)
|
buffer := new(bytes.Buffer)
|
||||||
tr := io.TeeReader(r, buffer)
|
tr := io.TeeReader(r, buffer)
|
||||||
var p payload.Payloader
|
var p payload.Payloader
|
||||||
|
@ -168,6 +157,8 @@ func (m *Message) decode(r io.Reader) error {
|
||||||
if err := p.Decode(tr); err != nil {
|
if err := p.Decode(tr); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unknown command to decode: %s", m.commandType())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Compare the checksum of the payload.
|
// Compare the checksum of the payload.
|
||||||
|
@ -186,14 +177,18 @@ func (m *Message) encode(w io.Writer) error {
|
||||||
pbuf := new(bytes.Buffer)
|
pbuf := new(bytes.Buffer)
|
||||||
|
|
||||||
// if there is a payload fill its allocated buffer.
|
// if there is a payload fill its allocated buffer.
|
||||||
|
var checksum []byte
|
||||||
if m.Payload != nil {
|
if m.Payload != nil {
|
||||||
if err := m.Payload.Encode(pbuf); err != nil {
|
if err := m.Payload.Encode(pbuf); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
checksum := sumSHA256(sumSHA256(pbuf.Bytes()))[:4]
|
checksum = sumSHA256(sumSHA256(pbuf.Bytes()))[:4]
|
||||||
m.Checksum = binary.LittleEndian.Uint32(checksum)
|
} else {
|
||||||
|
checksum = sumSHA256(sumSHA256([]byte{}))[:4]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
m.Checksum = binary.LittleEndian.Uint32(checksum)
|
||||||
|
|
||||||
// fill the message buffer
|
// fill the message buffer
|
||||||
binary.LittleEndian.PutUint32(buf[0:4], uint32(m.Magic))
|
binary.LittleEndian.PutUint32(buf[0:4], uint32(m.Magic))
|
||||||
copy(buf[4:16], m.Command)
|
copy(buf[4:16], m.Command)
|
||||||
|
@ -243,6 +238,10 @@ func cmdToByteSlice(cmd commandType) []byte {
|
||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func needPayloadDecode(cmd commandType) bool {
|
||||||
|
return cmd != cmdVerack && cmd != cmdGetAddr
|
||||||
|
}
|
||||||
|
|
||||||
func sumSHA256(b []byte) []byte {
|
func sumSHA256(b []byte) []byte {
|
||||||
h := sha256.New()
|
h := sha256.New()
|
||||||
h.Write(b)
|
h.Write(b)
|
||||||
|
|
|
@ -68,20 +68,20 @@ func TestMessageEncodeDecodeWithVersion(t *testing.T) {
|
||||||
t.Log(p1)
|
t.Log(p1)
|
||||||
}
|
}
|
||||||
|
|
||||||
// func TestMessageInvalidChecksum(t *testing.T) {
|
func TestMessageInvalidChecksum(t *testing.T) {
|
||||||
// m := newMessage(ModeTestNet, cmdVersion, []byte{})
|
m := newMessage(ModeTestNet, cmdVersion, nil)
|
||||||
// m.Checksum = 1337
|
m.Checksum = 1337
|
||||||
|
|
||||||
// buf := &bytes.Buffer{}
|
buf := &bytes.Buffer{}
|
||||||
// if err := m.encode(buf); err != nil {
|
if err := m.encode(buf); err != nil {
|
||||||
// t.Error(err)
|
t.Error(err)
|
||||||
// }
|
}
|
||||||
|
|
||||||
// md := &Message{}
|
md := &Message{}
|
||||||
// if err := md.decode(buf); err == nil {
|
if err := md.decode(buf); err == nil {
|
||||||
// t.Error("decode should failed with checkum mismatch error")
|
t.Error("decode should failed with checkum mismatch error")
|
||||||
// }
|
}
|
||||||
// }
|
}
|
||||||
|
|
||||||
// func TestNewVersionPayload(t *testing.T) {
|
// func TestNewVersionPayload(t *testing.T) {
|
||||||
// ua := "/neo/0.0.1/"
|
// ua := "/neo/0.0.1/"
|
||||||
|
|
|
@ -2,18 +2,10 @@ package payload
|
||||||
|
|
||||||
import "io"
|
import "io"
|
||||||
|
|
||||||
// Nothing is a safe non payload.
|
// Payloader is anything that can be binary encoded and decoded.
|
||||||
var Nothing = nothing{}
|
// Every payload used in messages need to satisfy the Payloader interface.
|
||||||
|
|
||||||
// Payloader ..
|
|
||||||
type Payloader interface {
|
type Payloader interface {
|
||||||
Encode(io.Writer) error
|
Encode(io.Writer) error
|
||||||
Decode(io.Reader) error
|
Decode(io.Reader) error
|
||||||
Size() uint32
|
Size() uint32
|
||||||
}
|
}
|
||||||
|
|
||||||
type nothing struct{}
|
|
||||||
|
|
||||||
func (p nothing) Encode(w io.Writer) error { return nil }
|
|
||||||
func (p nothing) Decode(R io.Reader) error { return nil }
|
|
||||||
func (p nothing) Size() uint32 { return 0 }
|
|
||||||
|
|
|
@ -1,12 +1,14 @@
|
||||||
package payload
|
package payload
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"io"
|
"io"
|
||||||
)
|
)
|
||||||
|
|
||||||
const minVersionSize = 27
|
const (
|
||||||
|
lenUA = 12
|
||||||
|
minVersionSize = 27 + lenUA
|
||||||
|
)
|
||||||
|
|
||||||
// Version payload.
|
// Version payload.
|
||||||
type Version struct {
|
type Version struct {
|
||||||
|
@ -20,7 +22,7 @@ type Version struct {
|
||||||
Port uint16
|
Port uint16
|
||||||
// it's used to distinguish the node from public IP
|
// it's used to distinguish the node from public IP
|
||||||
Nonce uint32
|
Nonce uint32
|
||||||
// client id currently 6 bytes \v/NEO:2.6.0/
|
// client id currently 12 bytes \v/NEO:2.6.0/
|
||||||
UserAgent []byte
|
UserAgent []byte
|
||||||
// Height of the block chain
|
// Height of the block chain
|
||||||
StartHeight uint32
|
StartHeight uint32
|
||||||
|
@ -42,20 +44,19 @@ func NewVersion(p uint16, ua string, h uint32, r bool) *Version {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Size ..
|
// Size implements the Payloader interface.
|
||||||
func (p *Version) Size() uint32 {
|
func (p *Version) Size() uint32 {
|
||||||
n := minVersionSize + len(p.UserAgent)
|
n := minVersionSize
|
||||||
return uint32(n)
|
return uint32(n)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Decode ..
|
// Decode implements the Payloader interface.
|
||||||
func (p *Version) Decode(r io.Reader) error {
|
func (p *Version) Decode(r io.Reader) error {
|
||||||
buf := new(bytes.Buffer)
|
b := make([]byte, minVersionSize)
|
||||||
if _, err := buf.ReadFrom(r); err != nil {
|
if _, err := r.Read(b); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
b := buf.Bytes()
|
|
||||||
// 27 bytes for the fixed size fields + the length of the user agent
|
// 27 bytes for the fixed size fields + the length of the user agent
|
||||||
// which is kinda variable, according to the docs.
|
// which is kinda variable, according to the docs.
|
||||||
lenUA := len(b) - minVersionSize
|
lenUA := len(b) - minVersionSize
|
||||||
|
@ -75,7 +76,7 @@ func (p *Version) Decode(r io.Reader) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Encode ..
|
// Encode implements the Payloader interface.
|
||||||
func (p *Version) Encode(w io.Writer) error {
|
func (p *Version) Encode(w io.Writer) error {
|
||||||
buf := make([]byte, p.Size())
|
buf := make([]byte, p.Size())
|
||||||
|
|
||||||
|
|
|
@ -38,7 +38,7 @@ func (p *Peer) writeLoop() {
|
||||||
|
|
||||||
for {
|
for {
|
||||||
msg := <-p.send
|
msg := <-p.send
|
||||||
rpcLogger.Printf("OUT :: %+v", msg)
|
rpcLogger.Printf("OUT :: %s", msg.commandType())
|
||||||
if err := msg.encode(p.conn); err != nil {
|
if err := msg.encode(p.conn); err != nil {
|
||||||
log.Printf("encode error: %s", err)
|
log.Printf("encode error: %s", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -167,11 +167,7 @@ func (s *Server) processMessage(msg *Message, peer *Peer) error {
|
||||||
|
|
||||||
switch msg.commandType() {
|
switch msg.commandType() {
|
||||||
case cmdVersion:
|
case cmdVersion:
|
||||||
// v, err := msg.decodePayload()
|
return s.handleVersionCmd(msg.Payload.(*payload.Version), peer)
|
||||||
// if err != nil {
|
|
||||||
// return err
|
|
||||||
// }
|
|
||||||
// return s.handleVersionCmd(v.(*Version), peer)
|
|
||||||
case cmdVerack:
|
case cmdVerack:
|
||||||
case cmdGetAddr:
|
case cmdGetAddr:
|
||||||
return s.handleGetAddrCmd(msg, peer)
|
return s.handleGetAddrCmd(msg, peer)
|
||||||
|
@ -226,10 +222,10 @@ func (s *Server) handleGetAddrCmd(msg *Message, peer *Peer) error {
|
||||||
// if err != nil {
|
// if err != nil {
|
||||||
// return err
|
// return err
|
||||||
// }
|
// }
|
||||||
var addrList []AddrWithTimestamp
|
// var addrList []AddrWithTimestamp
|
||||||
for peer := range s.peers {
|
// for peer := range s.peers {
|
||||||
addrList = append(addrList, newAddrWithTimestampFromPeer(peer))
|
// addrList = append(addrList, newAddrWithTimestampFromPeer(peer))
|
||||||
}
|
// }
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue