neo-go/pkg/wire/message.go
2019-02-25 22:44:14 +00:00

148 lines
3.3 KiB
Go

package wire
import (
"bufio"
"bytes"
"errors"
"io"
"github.com/CityOfZion/neo-go/pkg/wire/payload/transaction"
checksum "github.com/CityOfZion/neo-go/pkg/wire/util/Checksum"
"github.com/CityOfZion/neo-go/pkg/wire/command"
"github.com/CityOfZion/neo-go/pkg/wire/payload"
"github.com/CityOfZion/neo-go/pkg/wire/protocol"
"github.com/CityOfZion/neo-go/pkg/wire/util"
)
// Messager is the common interface of every message that WriteMessage can
// send and ReadMessage can return. A Messager names its command and knows how
// to serialize and deserialize its own payload; the header (magic, command,
// length, checksum) is handled by WriteMessage/ReadMessage, not by the message.
type Messager interface {
	// Command reports the command name placed in the message header.
	Command() command.Type
	// EncodePayload writes only the payload bytes of the message to w.
	EncodePayload(w io.Writer) error
	// DecodePayload fills the message from the payload bytes read from r.
	DecodePayload(r io.Reader) error
}
// minMsgSize is the size of the fixed message header that ReadMessage reads
// before any payload: magic (4) + command (12) + payload length (4) +
// checksum (4).
const minMsgSize = 4 + 12 + 4 + 4

// errChecksumMismatch is returned by ReadMessage when the received payload
// does not match the checksum announced in the header.
var errChecksumMismatch = errors.New("checksum mismatch")
// WriteMessage serializes message and writes it to w as a single frame:
// magic, the zero-padded command name, the payload length, the payload
// checksum and finally the payload bytes.
//
// The payload is encoded into memory, and the command field is built, before
// anything is written to w. If encoding fails (or the command name is too
// long), w is left untouched instead of receiving a half-written header that
// would desynchronize the stream.
//
// It returns the payload encoding error, if any, otherwise the error recorded
// by the underlying util.BinWriter while writing the frame.
func WriteMessage(w io.Writer, magic protocol.Magic, message Messager) error {
	// Do all fallible, side-effect-free work first.
	cmd := cmdToByteArray(message.Command())
	payloadBuf := new(bytes.Buffer)
	if err := message.EncodePayload(payloadBuf); err != nil {
		return err
	}
	payloadLen := util.BufferLength(payloadBuf)
	// Named sum (not checksum) so the checksum package is not shadowed.
	sum := checksum.FromBytes(payloadBuf.Bytes())

	bw := &util.BinWriter{W: w}
	bw.Write(magic)
	bw.Write(cmd)
	bw.Write(payloadLen)
	bw.Write(sum)
	bw.WriteBigEnd(payloadBuf.Bytes())
	return bw.Err
}
// ReadMessage reads one message frame from r and decodes it into the concrete
// Messager selected by the header's command.
//
// It reads the fixed minMsgSize-byte header, then exactly PayloadLength
// payload bytes, verifies the payload against the header checksum
// (errChecksumMismatch on failure) and dispatches on the command. On any
// error the returned Messager is nil.
//
// NOTE(review): the magic argument is never consulted, so the header's magic
// is not validated here. Presumably frames for a different network should be
// rejected; confirm against Base's fields. PayloadLength is also not bounded
// before the payload is read; consider enforcing the protocol's maximum
// payload size.
func ReadMessage(r io.Reader, magic protocol.Magic) (Messager, error) {
	hdr := make([]byte, minMsgSize)
	if _, err := io.ReadFull(r, hdr); err != nil {
		return nil, err
	}

	var header Base
	if _, err := header.DecodeBase(bytes.NewReader(hdr)); err != nil {
		// Keep the original message but do not drop the underlying cause.
		return nil, errors.New("Error decoding into the header base: " + err.Error())
	}

	buf := new(bytes.Buffer)
	if _, err := io.CopyN(buf, r, int64(header.PayloadLength)); err != nil {
		return nil, err
	}

	if !checksum.Compare(header.Checksum, buf.Bytes()) {
		return nil, errChecksumMismatch
	}

	// decode completes a message produced by a constructor. The constructor's
	// error is checked first (previously it was overwritten and ignored), then
	// the payload is decoded. Nil is returned on any failure.
	decode := func(m Messager, err error) (Messager, error) {
		if err != nil {
			return nil, err
		}
		if err := m.DecodePayload(buf); err != nil {
			return nil, err
		}
		return m, nil
	}

	switch header.CMD {
	case command.Version:
		return decode(&payload.VersionMessage{}, nil)
	case command.Verack:
		return decode(payload.NewVerackMessage())
	case command.Inv:
		return decode(payload.NewInvMessage(0))
	case command.GetAddr:
		return decode(payload.NewGetAddrMessage())
	case command.Addr:
		return decode(payload.NewAddrMessage())
	case command.Block:
		return decode(payload.NewBlockMessage())
	case command.GetBlocks:
		return decode(payload.NewGetBlocksMessage([]util.Uint256{}, util.Uint256{}))
	case command.GetData:
		return decode(payload.NewGetDataMessage(payload.InvTypeTx))
	case command.GetHeaders:
		return decode(payload.NewGetHeadersMessage([]util.Uint256{}, util.Uint256{}))
	case command.Headers:
		return decode(payload.NewHeadersMessage())
	case command.TX:
		tx, err := transaction.FromBytes(bufio.NewReader(buf))
		if err != nil {
			return nil, err
		}
		return payload.NewTXMessage(tx)
	}
	return nil, errors.New("Unknown Message found")
}
// cmdToByteArray lays cmd out in the fixed-width command field of a message
// header. The name is left-aligned and the remaining bytes stay zero. It
// panics when cmd is longer than command.Size bytes.
func cmdToByteArray(cmd command.Type) [command.Size]byte {
	var field [command.Size]byte
	if len(cmd) > len(field) {
		panic("exceeded command max length of size 12")
	}
	// copy stops at len(cmd); the tail of field keeps its zero value.
	copy(field[:], cmd)
	return field
}