mirror of
https://github.com/nspcc-dev/neo-go.git
synced 2024-11-27 23:31:35 +00:00
148 lines
3.3 KiB
Go
148 lines
3.3 KiB
Go
package wire
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"errors"
|
|
"io"
|
|
|
|
"github.com/CityOfZion/neo-go/pkg/wire/payload/transaction"
|
|
checksum "github.com/CityOfZion/neo-go/pkg/wire/util/Checksum"
|
|
|
|
"github.com/CityOfZion/neo-go/pkg/wire/command"
|
|
"github.com/CityOfZion/neo-go/pkg/wire/payload"
|
|
"github.com/CityOfZion/neo-go/pkg/wire/protocol"
|
|
"github.com/CityOfZion/neo-go/pkg/wire/util"
|
|
)
|
|
|
|
type Messager interface {
|
|
EncodePayload(w io.Writer) error
|
|
DecodePayload(r io.Reader) error
|
|
Command() command.Type
|
|
}
|
|
|
|
// minMsgSize is the byte size of the fixed message header that precedes
// every payload: magic (4) + command (12) + payload length (4) + checksum (4).
const minMsgSize = 4 + 12 + 4 + 4
|
|
|
|
// errChecksumMismatch is returned by ReadMessage when the checksum carried in
// the header does not match the payload bytes that were received.
var errChecksumMismatch = errors.New("checksum mismatch")
|
|
|
|
func WriteMessage(w io.Writer, magic protocol.Magic, message Messager) error {
|
|
bw := &util.BinWriter{W: w}
|
|
bw.Write(magic)
|
|
bw.Write(cmdToByteArray(message.Command()))
|
|
|
|
buf := new(bytes.Buffer)
|
|
if err := message.EncodePayload(buf); err != nil {
|
|
return err
|
|
}
|
|
|
|
payloadLen := util.BufferLength(buf)
|
|
checksum := checksum.FromBytes(buf.Bytes())
|
|
|
|
bw.Write(payloadLen)
|
|
bw.Write(checksum)
|
|
|
|
bw.WriteBigEnd(buf.Bytes())
|
|
|
|
return bw.Err
|
|
}
|
|
|
|
func ReadMessage(r io.Reader, magic protocol.Magic) (Messager, error) {
|
|
|
|
byt := make([]byte, minMsgSize)
|
|
|
|
if _, err := io.ReadFull(r, byt); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
reader := bytes.NewReader(byt)
|
|
|
|
var header Base
|
|
_, err := header.DecodeBase(reader)
|
|
|
|
if err != nil {
|
|
return nil, errors.New("Error decoding into the header base")
|
|
}
|
|
|
|
buf := new(bytes.Buffer)
|
|
|
|
_, err = io.CopyN(buf, r, int64(header.PayloadLength))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Compare the checksum of the payload.
|
|
if !checksum.Compare(header.Checksum, buf.Bytes()) {
|
|
return nil, errChecksumMismatch
|
|
}
|
|
switch header.CMD {
|
|
case command.Version:
|
|
v := &payload.VersionMessage{}
|
|
err := v.DecodePayload(buf)
|
|
return v, err
|
|
case command.Verack:
|
|
v, err := payload.NewVerackMessage()
|
|
err = v.DecodePayload(buf)
|
|
return v, err
|
|
case command.Inv:
|
|
v, err := payload.NewInvMessage(0)
|
|
err = v.DecodePayload(buf)
|
|
return v, err
|
|
case command.GetAddr:
|
|
v, err := payload.NewGetAddrMessage()
|
|
err = v.DecodePayload(buf)
|
|
return v, err
|
|
case command.Addr:
|
|
v, err := payload.NewAddrMessage()
|
|
err = v.DecodePayload(buf)
|
|
return v, err
|
|
case command.Block:
|
|
v, err := payload.NewBlockMessage()
|
|
err = v.DecodePayload(buf)
|
|
return v, err
|
|
case command.GetBlocks:
|
|
v, err := payload.NewGetBlocksMessage([]util.Uint256{}, util.Uint256{})
|
|
err = v.DecodePayload(buf)
|
|
return v, err
|
|
case command.GetData:
|
|
v, err := payload.NewGetDataMessage(payload.InvTypeTx)
|
|
err = v.DecodePayload(buf)
|
|
return v, err
|
|
case command.GetHeaders:
|
|
v, err := payload.NewGetHeadersMessage([]util.Uint256{}, util.Uint256{})
|
|
err = v.DecodePayload(buf)
|
|
return v, err
|
|
case command.Headers:
|
|
v, err := payload.NewHeadersMessage()
|
|
err = v.DecodePayload(buf)
|
|
return v, err
|
|
case command.TX:
|
|
reader := bufio.NewReader(buf)
|
|
tx, err := transaction.FromBytes(reader)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return payload.NewTXMessage(tx)
|
|
}
|
|
return nil, errors.New("Unknown Message found")
|
|
|
|
}
|
|
|
|
func cmdToByteArray(cmd command.Type) [command.Size]byte {
|
|
cmdLen := len(cmd)
|
|
if cmdLen > command.Size {
|
|
panic("exceeded command max length of size 12")
|
|
}
|
|
|
|
// The command can have max 12 bytes, rest is filled with 0.
|
|
b := [command.Size]byte{}
|
|
for i := 0; i < cmdLen; i++ {
|
|
b[i] = cmd[i]
|
|
}
|
|
|
|
return b
|
|
}
|