forked from TrueCloudLab/neoneo-go
256 lines
5.5 KiB
Go
256 lines
5.5 KiB
Go
package network
|
|
|
|
import (
|
|
"bytes"
|
|
"crypto/sha256"
|
|
"encoding/binary"
|
|
"errors"
|
|
"io"
|
|
|
|
"github.com/anthdm/neo-go/pkg/network/payload"
|
|
)
|
|
|
|
// minMessageSize is the size of the fixed message header that precedes every
// payload: 4 bytes magic, 12 bytes command, 4 bytes payload length and
// 4 bytes checksum.
const minMessageSize = 4 + 12 + 4 + 4
|
|
|
|
// NetMode identifies the network a message belongs to. It is the type of the
// Magic field of a Message.
type NetMode uint32

// String implements the fmt.Stringer interface. It returns "" for any value
// other than ModeMainNet, ModeTestNet and ModeDevNet.
func (n NetMode) String() string {
	switch n {
	case ModeDevNet:
		return "devnet"
	case ModeTestNet:
		return "testnet"
	case ModeMainNet:
		return "mainnet"
	default:
		return ""
	}
}

// Values used for the magic field, according to the docs.
//
// Each constant names NetMode explicitly: in a Go const block a type written
// on one line does not carry over to later lines that have their own value,
// so without it ModeTestNet and ModeDevNet would be untyped integers with no
// String method.
const (
	ModeMainNet NetMode = 0x00746e41 // 7630401
	ModeTestNet NetMode = 0x74746e41 // 1953787457
	ModeDevNet  NetMode = 56753      // docker privnet
)
|
|
|
|
// Message is the complete message send between nodes.
|
|
//
|
|
// Size Field DataType Description
|
|
// ------------------------------------------------------
|
|
// 4 Magic uint32 Protocol ID
|
|
// 12 Command char[12] Command
|
|
// 4 length uint32 Length of payload
|
|
// 4 Checksum uint32 Checksum
|
|
// length Payload uint8[length] Content of message
|
|
type Message struct {
|
|
Magic NetMode
|
|
// Command is utf8 code, of which the length is 12 bytes,
|
|
// the extra part is filled with 0.
|
|
Command []byte
|
|
// Length of the payload
|
|
Length uint32
|
|
// Checksum is the first 4 bytes of the value that two times SHA256
|
|
// hash of the payload
|
|
Checksum uint32
|
|
// Payload send with the message.
|
|
Payload payload.Payloader
|
|
}
|
|
|
|
// commandType is the name of a protocol command. On the wire it is carried in
// the 12-byte, zero-padded Command field of a Message.
type commandType string

// Valid commands used to send between nodes.
//
// Every constant names commandType explicitly: in a Go const block the type
// written on the first line does not carry over to later lines that have
// their own value, so without it all but cmdVersion would be untyped strings.
const (
	cmdVersion    commandType = "version"
	cmdVerack     commandType = "verack"
	cmdGetAddr    commandType = "getaddr"
	cmdAddr       commandType = "addr"
	cmdGetHeaders commandType = "getheaders"
	cmdHeaders    commandType = "headers"
	cmdGetBlocks  commandType = "getblocks"
	cmdInv        commandType = "inv"
	cmdGetData    commandType = "getdata"
	cmdBlock      commandType = "block"
	cmdTX         commandType = "tx"
)
|
|
|
|
func newMessage(magic NetMode, cmd commandType, p payload.Payloader) *Message {
|
|
var size uint32
|
|
if p != nil {
|
|
size = p.Size()
|
|
}
|
|
return &Message{
|
|
Magic: magic,
|
|
Command: cmdToByteSlice(cmd),
|
|
Length: size,
|
|
Payload: p,
|
|
}
|
|
}
|
|
|
|
// TeeWriter returns an io.Writer that writes each buffer to w and, when at
// least one byte was written successfully, then reads from r into the
// written prefix of that same buffer.
//
// NOTE(review): the io.Writer contract says Write must not modify the slice
// it is given, but this writer overwrites it with data read from r. Confirm
// that callers actually depend on this before relying on it.
func TeeWriter(w io.Writer, r io.Reader) io.Writer {
	return &teeWriter{w, r}
}

// teeWriter is the io.Writer returned by TeeWriter.
type teeWriter struct {
	w io.Writer
	r io.Reader
}

// Write writes b to w. If the write reports an error, or writes nothing, its
// count and error are returned unchanged. Otherwise up to n bytes are read
// from r into b[:n] (n being the number of bytes written), and the result
// of that Read is returned.
func (w *teeWriter) Write(b []byte) (n int, err error) {
	n, err = w.w.Write(b)
	if err != nil || n == 0 {
		// Report the write failure instead of letting a following Read
		// replace (and hide) it.
		return n, err
	}
	return w.r.Read(b[:n])
}
|
|
|
|
// Converts the 12 byte command slice to a commandType.
|
|
func (m *Message) commandType() commandType {
|
|
cmd := string(bytes.TrimRight(m.Command, "\x00"))
|
|
switch cmd {
|
|
case "version":
|
|
return cmdVersion
|
|
case "verack":
|
|
return cmdVerack
|
|
case "getaddr":
|
|
return cmdGetAddr
|
|
case "addr":
|
|
return cmdAddr
|
|
case "getheaders":
|
|
return cmdGetHeaders
|
|
case "header":
|
|
return cmdHeaders
|
|
case "getblocks":
|
|
return cmdGetBlocks
|
|
case "inv":
|
|
return cmdInv
|
|
case "getdata":
|
|
return cmdGetData
|
|
case "block":
|
|
return cmdBlock
|
|
case "tx":
|
|
return cmdTX
|
|
default:
|
|
return ""
|
|
}
|
|
}
|
|
|
|
// decode a Message from the given reader.
|
|
func (m *Message) decode(r io.Reader) error {
|
|
// 24 bytes for the fixed sized fields.
|
|
buf := make([]byte, minMessageSize)
|
|
if _, err := r.Read(buf); err != nil {
|
|
return err
|
|
}
|
|
|
|
m.Magic = NetMode(binary.LittleEndian.Uint32(buf[0:4]))
|
|
m.Command = buf[4:16]
|
|
m.Length = binary.LittleEndian.Uint32(buf[16:20])
|
|
m.Checksum = binary.LittleEndian.Uint32(buf[20:24])
|
|
|
|
// payload is 0, so dont decode it.
|
|
if m.Length == 0 {
|
|
return nil
|
|
}
|
|
|
|
buffer := new(bytes.Buffer)
|
|
tr := io.TeeReader(r, buffer)
|
|
var p payload.Payloader
|
|
|
|
switch m.commandType() {
|
|
case cmdVersion:
|
|
p = &payload.Version{}
|
|
if err := p.Decode(tr); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
// Compare the checksum of the payload.
|
|
if !compareChecksum(m.Checksum, buffer.Bytes()) {
|
|
return errors.New("checksum mismatch error")
|
|
}
|
|
|
|
m.Payload = p
|
|
|
|
return nil
|
|
}
|
|
|
|
// encode a Message to any given io.Writer.
|
|
func (m *Message) encode(w io.Writer) error {
|
|
buf := make([]byte, minMessageSize)
|
|
pbuf := new(bytes.Buffer)
|
|
|
|
// if there is a payload fill its allocated buffer.
|
|
if m.Payload != nil {
|
|
if err := m.Payload.Encode(pbuf); err != nil {
|
|
return err
|
|
}
|
|
checksum := sumSHA256(sumSHA256(pbuf.Bytes()))[:4]
|
|
m.Checksum = binary.LittleEndian.Uint32(checksum)
|
|
}
|
|
|
|
// fill the message buffer
|
|
binary.LittleEndian.PutUint32(buf[0:4], uint32(m.Magic))
|
|
copy(buf[4:16], m.Command)
|
|
binary.LittleEndian.PutUint32(buf[16:20], m.Length)
|
|
binary.LittleEndian.PutUint32(buf[20:24], m.Checksum)
|
|
|
|
// write the message
|
|
n, err := w.Write(buf)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// we need to have at least writen exactly minMessageSize bytes.
|
|
if n != minMessageSize {
|
|
return errors.New("long/short read error when encoding message")
|
|
}
|
|
|
|
// write the payload if there was any
|
|
if pbuf.Len() > 0 {
|
|
n, err = w.Write(pbuf.Bytes())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if uint32(n) != m.Payload.Size() {
|
|
return errors.New("long/short read error when encoding payload")
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// convert a command (string) to a byte slice filled with 0 bytes till
|
|
// size 12.
|
|
func cmdToByteSlice(cmd commandType) []byte {
|
|
cmdLen := len(cmd)
|
|
if cmdLen > 12 {
|
|
panic("exceeded command max length of size 12")
|
|
}
|
|
|
|
// The command can have max 12 bytes, rest is filled with 0.
|
|
b := []byte(cmd)
|
|
for i := 0; i < 12-cmdLen; i++ {
|
|
b = append(b, byte('\x00'))
|
|
}
|
|
|
|
return b
|
|
}
|
|
|
|
// sumSHA256 returns the SHA-256 digest of b as a 32-byte slice.
func sumSHA256(b []byte) []byte {
	digest := sha256.Sum256(b)
	return digest[:]
}
|
|
|
|
func compareChecksum(have uint32, b []byte) bool {
|
|
sum := sumSHA256(sumSHA256(b))[:4]
|
|
want := binary.LittleEndian.Uint32(sum)
|
|
return have == want
|
|
}
|