protocol: switch to binary MessageCommand

closes #888
This commit is contained in:
Anna Shaleva 2020-05-19 14:54:51 +03:00
parent 1317666167
commit 3bcc56bdcf
5 changed files with 140 additions and 117 deletions

View file

@ -11,20 +11,15 @@ import (
"github.com/nspcc-dev/neo-go/pkg/network/payload" "github.com/nspcc-dev/neo-go/pkg/network/payload"
) )
const ( //go:generate stringer -type=CommandType
// The minimum size of a valid message.
minMessageSize = 24
cmdSize = 12
)
// Message is the complete message send between nodes. // Message is the complete message send between nodes.
type Message struct { type Message struct {
// NetMode of the node that sends this message. // NetMode of the node that sends this message.
Magic config.NetMode Magic config.NetMode
// Command is utf8 code, of which the length is 12 bytes, // Command is byte command code.
// the extra part is filled with 0. Command CommandType
Command [cmdSize]byte
// Length of the payload. // Length of the payload.
Length uint32 Length uint32
@ -34,30 +29,41 @@ type Message struct {
} }
// CommandType represents the type of a message command. // CommandType represents the type of a message command.
type CommandType string type CommandType byte
// Valid protocol commands used to send between nodes. // Valid protocol commands used to send between nodes.
const ( const (
CMDAddr CommandType = "addr" // handshaking
CMDBlock CommandType = "block" CMDVersion CommandType = 0x00
CMDConsensus CommandType = "consensus" CMDVerack CommandType = 0x01
CMDFilterAdd CommandType = "filteradd"
CMDFilterClear CommandType = "filterclear" // connectivity
CMDFilterLoad CommandType = "filterload" CMDGetAddr CommandType = 0x10
CMDGetAddr CommandType = "getaddr" CMDAddr CommandType = 0x11
CMDGetBlocks CommandType = "getblocks" CMDPing CommandType = 0x18
CMDGetData CommandType = "getdata" CMDPong CommandType = 0x19
CMDGetHeaders CommandType = "getheaders"
CMDHeaders CommandType = "headers" // synchronization
CMDInv CommandType = "inv" CMDGetHeaders CommandType = 0x20
CMDMempool CommandType = "mempool" CMDHeaders CommandType = 0x21
CMDMerkleBlock CommandType = "merkleblock" CMDGetBlocks CommandType = 0x24
CMDPing CommandType = "ping" CMDMempool CommandType = 0x25
CMDPong CommandType = "pong" CMDInv CommandType = 0x27
CMDTX CommandType = "tx" CMDGetData CommandType = 0x28
CMDUnknown CommandType = "unknown" CMDUnknown CommandType = 0x2a
CMDVerack CommandType = "verack" CMDTX CommandType = 0x2b
CMDVersion CommandType = "version" CMDBlock CommandType = 0x2c
CMDConsensus CommandType = 0x2d
CMDReject CommandType = 0x2f
// SPV protocol
CMDFilterLoad CommandType = 0x30
CMDFilterAdd CommandType = 0x31
CMDFilterClear CommandType = 0x32
CMDMerkleBlock CommandType = 0x38
// others
CMDAlert CommandType = 0x40
) )
// NewMessage returns a new message with the given payload. // NewMessage returns a new message with the given payload.
@ -78,63 +84,16 @@ func NewMessage(magic config.NetMode, cmd CommandType, p payload.Payload) *Messa
return &Message{ return &Message{
Magic: magic, Magic: magic,
Command: cmdToByteArray(cmd), Command: cmd,
Length: size, Length: size,
Payload: p, Payload: p,
} }
} }
// CommandType converts the 12 byte command slice to a CommandType.
func (m *Message) CommandType() CommandType {
cmd := cmdByteArrayToString(m.Command)
switch cmd {
case "addr":
return CMDAddr
case "block":
return CMDBlock
case "consensus":
return CMDConsensus
case "filteradd":
return CMDFilterAdd
case "filterclear":
return CMDFilterClear
case "filterload":
return CMDFilterLoad
case "getaddr":
return CMDGetAddr
case "getblocks":
return CMDGetBlocks
case "getdata":
return CMDGetData
case "getheaders":
return CMDGetHeaders
case "headers":
return CMDHeaders
case "inv":
return CMDInv
case "mempool":
return CMDMempool
case "merkleblock":
return CMDMerkleBlock
case "ping":
return CMDPing
case "pong":
return CMDPong
case "tx":
return CMDTX
case "verack":
return CMDVerack
case "version":
return CMDVersion
default:
return CMDUnknown
}
}
// Decode decodes a Message from the given reader. // Decode decodes a Message from the given reader.
func (m *Message) Decode(br *io.BinReader) error { func (m *Message) Decode(br *io.BinReader) error {
m.Magic = config.NetMode(br.ReadU32LE()) m.Magic = config.NetMode(br.ReadU32LE())
br.ReadBytes(m.Command[:]) m.Command = CommandType(br.ReadB())
m.Length = br.ReadU32LE() m.Length = br.ReadU32LE()
if br.Err != nil { if br.Err != nil {
return br.Err return br.Err
@ -155,7 +114,7 @@ func (m *Message) decodePayload(br *io.BinReader) error {
r := io.NewBinReaderFromBuf(buf) r := io.NewBinReaderFromBuf(buf)
var p payload.Payload var p payload.Payload
switch m.CommandType() { switch m.Command {
case CMDVersion: case CMDVersion:
p = &payload.Version{} p = &payload.Version{}
case CMDInv, CMDGetData: case CMDInv, CMDGetData:
@ -179,7 +138,7 @@ func (m *Message) decodePayload(br *io.BinReader) error {
case CMDPing, CMDPong: case CMDPing, CMDPong:
p = &payload.Ping{} p = &payload.Ping{}
default: default:
return fmt.Errorf("can't decode command %s", cmdByteArrayToString(m.Command)) return fmt.Errorf("can't decode command %s", m.Command.String())
} }
p.DecodeBinary(r) p.DecodeBinary(r)
if r.Err == nil || r.Err == payload.ErrTooManyHeaders { if r.Err == nil || r.Err == payload.ErrTooManyHeaders {
@ -192,7 +151,7 @@ func (m *Message) decodePayload(br *io.BinReader) error {
// Encode encodes a Message to any given BinWriter. // Encode encodes a Message to any given BinWriter.
func (m *Message) Encode(br *io.BinWriter) error { func (m *Message) Encode(br *io.BinWriter) error {
br.WriteU32LE(uint32(m.Magic)) br.WriteU32LE(uint32(m.Magic))
br.WriteBytes(m.Command[:]) br.WriteB(byte(m.Command))
br.WriteU32LE(m.Length) br.WriteU32LE(m.Length)
if m.Payload != nil { if m.Payload != nil {
m.Payload.EncodeBinary(br) m.Payload.EncodeBinary(br)
@ -215,30 +174,3 @@ func (m *Message) Bytes() ([]byte, error) {
} }
return w.Bytes(), nil return w.Bytes(), nil
} }
// convert a command (string) to a byte slice filled with 0 bytes till
// size 12.
func cmdToByteArray(cmd CommandType) [cmdSize]byte {
cmdLen := len(cmd)
if cmdLen > cmdSize {
panic("exceeded command max length of size 12")
}
// The command can have max 12 bytes, rest is filled with 0.
b := [cmdSize]byte{}
for i := 0; i < cmdLen; i++ {
b[i] = cmd[i]
}
return b
}
func cmdByteArrayToString(cmd [cmdSize]byte) string {
buf := make([]byte, 0, cmdSize)
for i := 0; i < cmdSize; i++ {
if cmd[i] != 0 {
buf = append(buf, cmd[i])
}
}
return string(buf)
}

View file

@ -0,0 +1,91 @@
// Code generated by "stringer -type=CommandType"; DO NOT EDIT.
package network
import "strconv"
func _() {
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
var x [1]struct{}
_ = x[CMDVersion-0]
_ = x[CMDVerack-1]
_ = x[CMDGetAddr-16]
_ = x[CMDAddr-17]
_ = x[CMDPing-24]
_ = x[CMDPong-25]
_ = x[CMDGetHeaders-32]
_ = x[CMDHeaders-33]
_ = x[CMDGetBlocks-36]
_ = x[CMDMempool-37]
_ = x[CMDInv-39]
_ = x[CMDGetData-40]
_ = x[CMDUnknown-42]
_ = x[CMDTX-43]
_ = x[CMDBlock-44]
_ = x[CMDConsensus-45]
_ = x[CMDReject-47]
_ = x[CMDFilterLoad-48]
_ = x[CMDFilterAdd-49]
_ = x[CMDFilterClear-50]
_ = x[CMDMerkleBlock-56]
_ = x[CMDAlert-64]
}
const (
_CommandType_name_0 = "CMDVersionCMDVerack"
_CommandType_name_1 = "CMDGetAddrCMDAddr"
_CommandType_name_2 = "CMDPingCMDPong"
_CommandType_name_3 = "CMDGetHeadersCMDHeaders"
_CommandType_name_4 = "CMDGetBlocksCMDMempool"
_CommandType_name_5 = "CMDInvCMDGetData"
_CommandType_name_6 = "CMDUnknownCMDTXCMDBlockCMDConsensus"
_CommandType_name_7 = "CMDRejectCMDFilterLoadCMDFilterAddCMDFilterClear"
_CommandType_name_8 = "CMDMerkleBlock"
_CommandType_name_9 = "CMDAlert"
)
var (
_CommandType_index_0 = [...]uint8{0, 10, 19}
_CommandType_index_1 = [...]uint8{0, 10, 17}
_CommandType_index_2 = [...]uint8{0, 7, 14}
_CommandType_index_3 = [...]uint8{0, 13, 23}
_CommandType_index_4 = [...]uint8{0, 12, 22}
_CommandType_index_5 = [...]uint8{0, 6, 16}
_CommandType_index_6 = [...]uint8{0, 10, 15, 23, 35}
_CommandType_index_7 = [...]uint8{0, 9, 22, 34, 48}
)
func (i CommandType) String() string {
switch {
case i <= 1:
return _CommandType_name_0[_CommandType_index_0[i]:_CommandType_index_0[i+1]]
case 16 <= i && i <= 17:
i -= 16
return _CommandType_name_1[_CommandType_index_1[i]:_CommandType_index_1[i+1]]
case 24 <= i && i <= 25:
i -= 24
return _CommandType_name_2[_CommandType_index_2[i]:_CommandType_index_2[i+1]]
case 32 <= i && i <= 33:
i -= 32
return _CommandType_name_3[_CommandType_index_3[i]:_CommandType_index_3[i+1]]
case 36 <= i && i <= 37:
i -= 36
return _CommandType_name_4[_CommandType_index_4[i]:_CommandType_index_4[i+1]]
case 39 <= i && i <= 40:
i -= 39
return _CommandType_name_5[_CommandType_index_5[i]:_CommandType_index_5[i+1]]
case 42 <= i && i <= 45:
i -= 42
return _CommandType_name_6[_CommandType_index_6[i]:_CommandType_index_6[i+1]]
case 47 <= i && i <= 50:
i -= 47
return _CommandType_name_7[_CommandType_index_7[i]:_CommandType_index_7[i+1]]
case i == 56:
return _CommandType_name_8
case i == 64:
return _CommandType_name_9
default:
return "CommandType(" + strconv.FormatInt(int64(i), 10) + ")"
}
}

View file

@ -671,7 +671,7 @@ func (s *Server) requestBlocks(p Peer) error {
func (s *Server) handleMessage(peer Peer, msg *Message) error { func (s *Server) handleMessage(peer Peer, msg *Message) error {
s.log.Debug("got msg", s.log.Debug("got msg",
zap.Stringer("addr", peer.RemoteAddr()), zap.Stringer("addr", peer.RemoteAddr()),
zap.String("type", string(msg.CommandType()))) zap.String("type", msg.Command.String()))
// Make sure both server and peer are operating on // Make sure both server and peer are operating on
// the same network. // the same network.
@ -685,7 +685,7 @@ func (s *Server) handleMessage(peer Peer, msg *Message) error {
return errInvalidInvType return errInvalidInvType
} }
} }
switch msg.CommandType() { switch msg.Command {
case CMDAddr: case CMDAddr:
addrs := msg.Payload.(*payload.AddressList) addrs := msg.Payload.(*payload.AddressList)
return s.handleAddrCmd(peer, addrs) return s.handleAddrCmd(peer, addrs)
@ -723,10 +723,10 @@ func (s *Server) handleMessage(peer Peer, msg *Message) error {
pong := msg.Payload.(*payload.Ping) pong := msg.Payload.(*payload.Ping)
return s.handlePong(peer, pong) return s.handlePong(peer, pong)
case CMDVersion, CMDVerack: case CMDVersion, CMDVerack:
return fmt.Errorf("received '%s' after the handshake", msg.CommandType()) return fmt.Errorf("received '%s' after the handshake", msg.Command.String())
} }
} else { } else {
switch msg.CommandType() { switch msg.Command {
case CMDVersion: case CMDVersion:
version := msg.Payload.(*payload.Version) version := msg.Payload.(*payload.Version)
return s.handleVersionCmd(peer, version) return s.handleVersionCmd(peer, version)
@ -739,7 +739,7 @@ func (s *Server) handleMessage(peer Peer, msg *Message) error {
s.tryStartConsensus() s.tryStartConsensus()
default: default:
return fmt.Errorf("received '%s' during handshake", msg.CommandType()) return fmt.Errorf("received '%s' during handshake", msg.Command.String())
} }
} }
return nil return nil

View file

@ -18,7 +18,7 @@ func TestSendVersion(t *testing.T) {
s.UserAgent = "/test/" s.UserAgent = "/test/"
p.messageHandler = func(t *testing.T, msg *Message) { p.messageHandler = func(t *testing.T, msg *Message) {
assert.Equal(t, CMDVersion, msg.CommandType()) assert.Equal(t, CMDVersion, msg.Command)
assert.IsType(t, msg.Payload, &payload.Version{}) assert.IsType(t, msg.Payload, &payload.Version{})
version := msg.Payload.(*payload.Version) version := msg.Payload.(*payload.Version)
assert.NotZero(t, version.Nonce) assert.NotZero(t, version.Nonce)
@ -43,7 +43,7 @@ func TestVerackAfterHandleVersionCmd(t *testing.T) {
// Should have a verack // Should have a verack
p.messageHandler = func(t *testing.T, msg *Message) { p.messageHandler = func(t *testing.T, msg *Message) {
assert.Equal(t, CMDVerack, msg.CommandType()) assert.Equal(t, CMDVerack, msg.Command)
} }
version := payload.NewVersion(1337, 3000, "/NEO-GO/", 0, true) version := payload.NewVersion(1337, 3000, "/NEO-GO/", 0, true)
@ -101,7 +101,7 @@ func TestRequestHeaders(t *testing.T) {
) )
p.messageHandler = func(t *testing.T, msg *Message) { p.messageHandler = func(t *testing.T, msg *Message) {
assert.IsType(t, &payload.GetBlocks{}, msg.Payload) assert.IsType(t, &payload.GetBlocks{}, msg.Payload)
assert.Equal(t, CMDGetHeaders, msg.CommandType()) assert.Equal(t, CMDGetHeaders, msg.Command)
} }
s.requestHeaders(p) s.requestHeaders(p)
} }

View file

@ -159,7 +159,7 @@ func (p *TCPPeer) handleConn() {
} }
if err = p.server.handleMessage(p, msg); err != nil { if err = p.server.handleMessage(p, msg); err != nil {
if p.Handshaked() { if p.Handshaked() {
err = fmt.Errorf("handling %s message: %v", msg.CommandType(), err) err = fmt.Errorf("handling %s message: %v", msg.Command.String(), err)
} }
break break
} }