protocol: switch to binary MessageCommand

closes #888
This commit is contained in:
Anna Shaleva 2020-05-19 14:54:51 +03:00
parent 1317666167
commit 3bcc56bdcf
5 changed files with 140 additions and 117 deletions

View file

@ -11,20 +11,15 @@ import (
"github.com/nspcc-dev/neo-go/pkg/network/payload"
)
const (
// The minimum size of a valid message.
minMessageSize = 24
cmdSize = 12
)
//go:generate stringer -type=CommandType
// Message is the complete message send between nodes.
type Message struct {
// NetMode of the node that sends this message.
Magic config.NetMode
// Command is utf8 code, of which the length is 12 bytes,
// the extra part is filled with 0.
Command [cmdSize]byte
// Command is byte command code.
Command CommandType
// Length of the payload.
Length uint32
@ -34,30 +29,41 @@ type Message struct {
}
// CommandType represents the type of a message command.
type CommandType string
type CommandType byte
// Valid protocol commands used to send between nodes.
const (
CMDAddr CommandType = "addr"
CMDBlock CommandType = "block"
CMDConsensus CommandType = "consensus"
CMDFilterAdd CommandType = "filteradd"
CMDFilterClear CommandType = "filterclear"
CMDFilterLoad CommandType = "filterload"
CMDGetAddr CommandType = "getaddr"
CMDGetBlocks CommandType = "getblocks"
CMDGetData CommandType = "getdata"
CMDGetHeaders CommandType = "getheaders"
CMDHeaders CommandType = "headers"
CMDInv CommandType = "inv"
CMDMempool CommandType = "mempool"
CMDMerkleBlock CommandType = "merkleblock"
CMDPing CommandType = "ping"
CMDPong CommandType = "pong"
CMDTX CommandType = "tx"
CMDUnknown CommandType = "unknown"
CMDVerack CommandType = "verack"
CMDVersion CommandType = "version"
// handshaking
CMDVersion CommandType = 0x00
CMDVerack CommandType = 0x01
// connectivity
CMDGetAddr CommandType = 0x10
CMDAddr CommandType = 0x11
CMDPing CommandType = 0x18
CMDPong CommandType = 0x19
// synchronization
CMDGetHeaders CommandType = 0x20
CMDHeaders CommandType = 0x21
CMDGetBlocks CommandType = 0x24
CMDMempool CommandType = 0x25
CMDInv CommandType = 0x27
CMDGetData CommandType = 0x28
CMDUnknown CommandType = 0x2a
CMDTX CommandType = 0x2b
CMDBlock CommandType = 0x2c
CMDConsensus CommandType = 0x2d
CMDReject CommandType = 0x2f
// SPV protocol
CMDFilterLoad CommandType = 0x30
CMDFilterAdd CommandType = 0x31
CMDFilterClear CommandType = 0x32
CMDMerkleBlock CommandType = 0x38
// others
CMDAlert CommandType = 0x40
)
// NewMessage returns a new message with the given payload.
@ -78,63 +84,16 @@ func NewMessage(magic config.NetMode, cmd CommandType, p payload.Payload) *Messa
return &Message{
Magic: magic,
Command: cmdToByteArray(cmd),
Command: cmd,
Length: size,
Payload: p,
}
}
// CommandType converts the 12 byte command slice to a CommandType.
func (m *Message) CommandType() CommandType {
cmd := cmdByteArrayToString(m.Command)
switch cmd {
case "addr":
return CMDAddr
case "block":
return CMDBlock
case "consensus":
return CMDConsensus
case "filteradd":
return CMDFilterAdd
case "filterclear":
return CMDFilterClear
case "filterload":
return CMDFilterLoad
case "getaddr":
return CMDGetAddr
case "getblocks":
return CMDGetBlocks
case "getdata":
return CMDGetData
case "getheaders":
return CMDGetHeaders
case "headers":
return CMDHeaders
case "inv":
return CMDInv
case "mempool":
return CMDMempool
case "merkleblock":
return CMDMerkleBlock
case "ping":
return CMDPing
case "pong":
return CMDPong
case "tx":
return CMDTX
case "verack":
return CMDVerack
case "version":
return CMDVersion
default:
return CMDUnknown
}
}
// Decode decodes a Message from the given reader.
func (m *Message) Decode(br *io.BinReader) error {
m.Magic = config.NetMode(br.ReadU32LE())
br.ReadBytes(m.Command[:])
m.Command = CommandType(br.ReadB())
m.Length = br.ReadU32LE()
if br.Err != nil {
return br.Err
@ -155,7 +114,7 @@ func (m *Message) decodePayload(br *io.BinReader) error {
r := io.NewBinReaderFromBuf(buf)
var p payload.Payload
switch m.CommandType() {
switch m.Command {
case CMDVersion:
p = &payload.Version{}
case CMDInv, CMDGetData:
@ -179,7 +138,7 @@ func (m *Message) decodePayload(br *io.BinReader) error {
case CMDPing, CMDPong:
p = &payload.Ping{}
default:
return fmt.Errorf("can't decode command %s", cmdByteArrayToString(m.Command))
return fmt.Errorf("can't decode command %s", m.Command.String())
}
p.DecodeBinary(r)
if r.Err == nil || r.Err == payload.ErrTooManyHeaders {
@ -192,7 +151,7 @@ func (m *Message) decodePayload(br *io.BinReader) error {
// Encode encodes a Message to any given BinWriter.
func (m *Message) Encode(br *io.BinWriter) error {
br.WriteU32LE(uint32(m.Magic))
br.WriteBytes(m.Command[:])
br.WriteB(byte(m.Command))
br.WriteU32LE(m.Length)
if m.Payload != nil {
m.Payload.EncodeBinary(br)
@ -215,30 +174,3 @@ func (m *Message) Bytes() ([]byte, error) {
}
return w.Bytes(), nil
}
// convert a command (string) to a byte slice filled with 0 bytes till
// size 12.
func cmdToByteArray(cmd CommandType) [cmdSize]byte {
cmdLen := len(cmd)
if cmdLen > cmdSize {
panic("exceeded command max length of size 12")
}
// The command can have max 12 bytes, rest is filled with 0.
b := [cmdSize]byte{}
for i := 0; i < cmdLen; i++ {
b[i] = cmd[i]
}
return b
}
func cmdByteArrayToString(cmd [cmdSize]byte) string {
buf := make([]byte, 0, cmdSize)
for i := 0; i < cmdSize; i++ {
if cmd[i] != 0 {
buf = append(buf, cmd[i])
}
}
return string(buf)
}

View file

@ -0,0 +1,91 @@
// Code generated by "stringer -type=CommandType"; DO NOT EDIT.
package network
import "strconv"
func _() {
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
var x [1]struct{}
_ = x[CMDVersion-0]
_ = x[CMDVerack-1]
_ = x[CMDGetAddr-16]
_ = x[CMDAddr-17]
_ = x[CMDPing-24]
_ = x[CMDPong-25]
_ = x[CMDGetHeaders-32]
_ = x[CMDHeaders-33]
_ = x[CMDGetBlocks-36]
_ = x[CMDMempool-37]
_ = x[CMDInv-39]
_ = x[CMDGetData-40]
_ = x[CMDUnknown-42]
_ = x[CMDTX-43]
_ = x[CMDBlock-44]
_ = x[CMDConsensus-45]
_ = x[CMDReject-47]
_ = x[CMDFilterLoad-48]
_ = x[CMDFilterAdd-49]
_ = x[CMDFilterClear-50]
_ = x[CMDMerkleBlock-56]
_ = x[CMDAlert-64]
}
const (
_CommandType_name_0 = "CMDVersionCMDVerack"
_CommandType_name_1 = "CMDGetAddrCMDAddr"
_CommandType_name_2 = "CMDPingCMDPong"
_CommandType_name_3 = "CMDGetHeadersCMDHeaders"
_CommandType_name_4 = "CMDGetBlocksCMDMempool"
_CommandType_name_5 = "CMDInvCMDGetData"
_CommandType_name_6 = "CMDUnknownCMDTXCMDBlockCMDConsensus"
_CommandType_name_7 = "CMDRejectCMDFilterLoadCMDFilterAddCMDFilterClear"
_CommandType_name_8 = "CMDMerkleBlock"
_CommandType_name_9 = "CMDAlert"
)
var (
_CommandType_index_0 = [...]uint8{0, 10, 19}
_CommandType_index_1 = [...]uint8{0, 10, 17}
_CommandType_index_2 = [...]uint8{0, 7, 14}
_CommandType_index_3 = [...]uint8{0, 13, 23}
_CommandType_index_4 = [...]uint8{0, 12, 22}
_CommandType_index_5 = [...]uint8{0, 6, 16}
_CommandType_index_6 = [...]uint8{0, 10, 15, 23, 35}
_CommandType_index_7 = [...]uint8{0, 9, 22, 34, 48}
)
func (i CommandType) String() string {
switch {
case i <= 1:
return _CommandType_name_0[_CommandType_index_0[i]:_CommandType_index_0[i+1]]
case 16 <= i && i <= 17:
i -= 16
return _CommandType_name_1[_CommandType_index_1[i]:_CommandType_index_1[i+1]]
case 24 <= i && i <= 25:
i -= 24
return _CommandType_name_2[_CommandType_index_2[i]:_CommandType_index_2[i+1]]
case 32 <= i && i <= 33:
i -= 32
return _CommandType_name_3[_CommandType_index_3[i]:_CommandType_index_3[i+1]]
case 36 <= i && i <= 37:
i -= 36
return _CommandType_name_4[_CommandType_index_4[i]:_CommandType_index_4[i+1]]
case 39 <= i && i <= 40:
i -= 39
return _CommandType_name_5[_CommandType_index_5[i]:_CommandType_index_5[i+1]]
case 42 <= i && i <= 45:
i -= 42
return _CommandType_name_6[_CommandType_index_6[i]:_CommandType_index_6[i+1]]
case 47 <= i && i <= 50:
i -= 47
return _CommandType_name_7[_CommandType_index_7[i]:_CommandType_index_7[i+1]]
case i == 56:
return _CommandType_name_8
case i == 64:
return _CommandType_name_9
default:
return "CommandType(" + strconv.FormatInt(int64(i), 10) + ")"
}
}

View file

@ -671,7 +671,7 @@ func (s *Server) requestBlocks(p Peer) error {
func (s *Server) handleMessage(peer Peer, msg *Message) error {
s.log.Debug("got msg",
zap.Stringer("addr", peer.RemoteAddr()),
zap.String("type", string(msg.CommandType())))
zap.String("type", msg.Command.String()))
// Make sure both server and peer are operating on
// the same network.
@ -685,7 +685,7 @@ func (s *Server) handleMessage(peer Peer, msg *Message) error {
return errInvalidInvType
}
}
switch msg.CommandType() {
switch msg.Command {
case CMDAddr:
addrs := msg.Payload.(*payload.AddressList)
return s.handleAddrCmd(peer, addrs)
@ -723,10 +723,10 @@ func (s *Server) handleMessage(peer Peer, msg *Message) error {
pong := msg.Payload.(*payload.Ping)
return s.handlePong(peer, pong)
case CMDVersion, CMDVerack:
return fmt.Errorf("received '%s' after the handshake", msg.CommandType())
return fmt.Errorf("received '%s' after the handshake", msg.Command.String())
}
} else {
switch msg.CommandType() {
switch msg.Command {
case CMDVersion:
version := msg.Payload.(*payload.Version)
return s.handleVersionCmd(peer, version)
@ -739,7 +739,7 @@ func (s *Server) handleMessage(peer Peer, msg *Message) error {
s.tryStartConsensus()
default:
return fmt.Errorf("received '%s' during handshake", msg.CommandType())
return fmt.Errorf("received '%s' during handshake", msg.Command.String())
}
}
return nil

View file

@ -18,7 +18,7 @@ func TestSendVersion(t *testing.T) {
s.UserAgent = "/test/"
p.messageHandler = func(t *testing.T, msg *Message) {
assert.Equal(t, CMDVersion, msg.CommandType())
assert.Equal(t, CMDVersion, msg.Command)
assert.IsType(t, msg.Payload, &payload.Version{})
version := msg.Payload.(*payload.Version)
assert.NotZero(t, version.Nonce)
@ -43,7 +43,7 @@ func TestVerackAfterHandleVersionCmd(t *testing.T) {
// Should have a verack
p.messageHandler = func(t *testing.T, msg *Message) {
assert.Equal(t, CMDVerack, msg.CommandType())
assert.Equal(t, CMDVerack, msg.Command)
}
version := payload.NewVersion(1337, 3000, "/NEO-GO/", 0, true)
@ -101,7 +101,7 @@ func TestRequestHeaders(t *testing.T) {
)
p.messageHandler = func(t *testing.T, msg *Message) {
assert.IsType(t, &payload.GetBlocks{}, msg.Payload)
assert.Equal(t, CMDGetHeaders, msg.CommandType())
assert.Equal(t, CMDGetHeaders, msg.Command)
}
s.requestHeaders(p)
}

View file

@ -159,7 +159,7 @@ func (p *TCPPeer) handleConn() {
}
if err = p.server.handleMessage(p, msg); err != nil {
if p.Handshaked() {
err = fmt.Errorf("handling %s message: %v", msg.CommandType(), err)
err = fmt.Errorf("handling %s message: %v", msg.Command.String(), err)
}
break
}