neo-go/pkg/network/message.go
Roman Khimov e72a8bbecd network: decode CMDGetData the same way as CMDInv
As they're the same wrt the payload.
2019-08-30 13:14:12 +03:00

280 lines
6.1 KiB
Go

package network
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
"github.com/CityOfZion/neo-go/config"
"github.com/CityOfZion/neo-go/pkg/core"
"github.com/CityOfZion/neo-go/pkg/core/transaction"
"github.com/CityOfZion/neo-go/pkg/crypto/hash"
"github.com/CityOfZion/neo-go/pkg/network/payload"
"github.com/CityOfZion/neo-go/pkg/util"
)
const (
// The minimum size of a valid message.
minMessageSize = 24
cmdSize = 12
)
var (
errChecksumMismatch = errors.New("checksum mismatch")
)
// Message is the complete message send between nodes.
type Message struct {
// NetMode of the node that sends this message.
Magic config.NetMode
// Command is utf8 code, of which the length is 12 bytes,
// the extra part is filled with 0.
Command [cmdSize]byte
// Length of the payload
Length uint32
// Checksum is the first 4 bytes of the value that two times SHA256
// hash of the payload
Checksum uint32
// Payload send with the message.
Payload payload.Payload
}
// CommandType represents the type of a message command.
type CommandType string
// Valid protocol commands used to send between nodes.
const (
CMDAddr CommandType = "addr"
CMDBlock CommandType = "block"
CMDConsensus CommandType = "consensus"
CMDFilterAdd CommandType = "filteradd"
CMDFilterClear CommandType = "filterclear"
CMDFilterLoad CommandType = "filterload"
CMDGetAddr CommandType = "getaddr"
CMDGetBlocks CommandType = "getblocks"
CMDGetData CommandType = "getdata"
CMDGetHeaders CommandType = "getheaders"
CMDHeaders CommandType = "headers"
CMDInv CommandType = "inv"
CMDMempool CommandType = "mempool"
CMDMerkleBlock CommandType = "merkleblock"
CMDPing CommandType = "ping"
CMDPong CommandType = "pong"
CMDTX CommandType = "tx"
CMDUnknown CommandType = "unknown"
CMDVerack CommandType = "verack"
CMDVersion CommandType = "version"
)
// NewMessage returns a new message with the given payload.
func NewMessage(magic config.NetMode, cmd CommandType, p payload.Payload) *Message {
var (
size uint32
checksum []byte
)
if p != nil {
buf := new(bytes.Buffer)
if err := p.EncodeBinary(buf); err != nil {
panic(err)
}
size = uint32(buf.Len())
checksum = hash.Checksum(buf.Bytes())
} else {
checksum = hash.Checksum([]byte{})
}
return &Message{
Magic: magic,
Command: cmdToByteArray(cmd),
Length: size,
Payload: p,
Checksum: binary.LittleEndian.Uint32(checksum[:4]),
}
}
// CommandType converts the 12 byte command slice to a CommandType.
func (m *Message) CommandType() CommandType {
cmd := cmdByteArrayToString(m.Command)
switch cmd {
case "addr":
return CMDAddr
case "block":
return CMDBlock
case "consensus":
return CMDConsensus
case "filteradd":
return CMDFilterAdd
case "filterclear":
return CMDFilterClear
case "filterload":
return CMDFilterLoad
case "getaddr":
return CMDGetAddr
case "getblocks":
return CMDGetBlocks
case "getdata":
return CMDGetData
case "getheaders":
return CMDGetHeaders
case "headers":
return CMDHeaders
case "inv":
return CMDInv
case "mempool":
return CMDMempool
case "merkleblock":
return CMDMerkleBlock
case "ping":
return CMDPing
case "pong":
return CMDPong
case "tx":
return CMDTX
case "verack":
return CMDVerack
case "version":
return CMDVersion
default:
return CMDUnknown
}
}
// Decode a Message from the given reader.
func (m *Message) Decode(r io.Reader) error {
br := util.BinReader{R: r}
br.ReadLE(&m.Magic)
br.ReadLE(&m.Command)
br.ReadLE(&m.Length)
br.ReadLE(&m.Checksum)
if br.Err != nil {
return br.Err
}
// return if their is no payload.
if m.Length == 0 {
return nil
}
return m.decodePayload(r)
}
func (m *Message) decodePayload(r io.Reader) error {
buf := new(bytes.Buffer)
n, err := io.CopyN(buf, r, int64(m.Length))
if err != nil {
return err
}
if uint32(n) != m.Length {
return fmt.Errorf("expected to have read exactly %d bytes got %d", m.Length, n)
}
// Compare the checksum of the payload.
if !compareChecksum(m.Checksum, buf.Bytes()) {
return errChecksumMismatch
}
var p payload.Payload
switch m.CommandType() {
case CMDVersion:
p = &payload.Version{}
if err := p.DecodeBinary(buf); err != nil {
return err
}
case CMDInv, CMDGetData:
p = &payload.Inventory{}
if err := p.DecodeBinary(buf); err != nil {
return err
}
case CMDAddr:
p = &payload.AddressList{}
if err := p.DecodeBinary(buf); err != nil {
return err
}
case CMDBlock:
p = &core.Block{}
if err := p.DecodeBinary(buf); err != nil {
return err
}
case CMDGetBlocks:
fallthrough
case CMDGetHeaders:
p = &payload.GetBlocks{}
if err := p.DecodeBinary(buf); err != nil {
return err
}
case CMDHeaders:
p = &payload.Headers{}
if err := p.DecodeBinary(buf); err != nil {
return err
}
case CMDTX:
p = &transaction.Transaction{}
if err := p.DecodeBinary(buf); err != nil {
return err
}
case CMDMerkleBlock:
p = &payload.MerkleBlock{}
if err := p.DecodeBinary(buf); err != nil {
return err
}
}
m.Payload = p
return nil
}
// Encode a Message to any given io.Writer.
func (m *Message) Encode(w io.Writer) error {
br := util.BinWriter{W: w}
br.WriteLE(m.Magic)
br.WriteLE(m.Command)
br.WriteLE(m.Length)
br.WriteLE(m.Checksum)
if br.Err != nil {
return br.Err
}
if m.Payload != nil {
return m.Payload.EncodeBinary(w)
}
return nil
}
// convert a command (string) to a byte slice filled with 0 bytes till
// size 12.
func cmdToByteArray(cmd CommandType) [cmdSize]byte {
cmdLen := len(cmd)
if cmdLen > cmdSize {
panic("exceeded command max length of size 12")
}
// The command can have max 12 bytes, rest is filled with 0.
b := [cmdSize]byte{}
for i := 0; i < cmdLen; i++ {
b[i] = cmd[i]
}
return b
}
func cmdByteArrayToString(cmd [cmdSize]byte) string {
buf := make([]byte, 0, cmdSize)
for i := 0; i < cmdSize; i++ {
if cmd[i] != 0 {
buf = append(buf, cmd[i])
}
}
return string(buf)
}
func compareChecksum(have uint32, b []byte) bool {
sum := hash.Checksum(b)
want := binary.LittleEndian.Uint32(sum)
return have == want
}