network: return error if header message is too big

Big messages can still be processed but only first
2000 headers will be used.
This commit is contained in:
Evgenii Stratonikov 2019-12-30 15:38:23 +03:00
parent 637c99eda7
commit 9b8b77c9ea
4 changed files with 62 additions and 38 deletions

View file

@ -201,13 +201,11 @@ func (m *Message) decodePayload(br *io.BinReader) error {
return fmt.Errorf("can't decode command %s", cmdByteArrayToString(m.Command))
}
p.DecodeBinary(r)
if r.Err != nil {
return r.Err
if r.Err == nil || r.Err == payload.ErrTooManyHeaders {
m.Payload = p
}
m.Payload = p
return nil
return r.Err
}
// Encode encodes a Message to any given BinWriter.

View file

@ -3,6 +3,7 @@ package payload
import (
"github.com/CityOfZion/neo-go/pkg/core"
"github.com/CityOfZion/neo-go/pkg/io"
"github.com/pkg/errors"
)
// Headers payload.
@ -15,12 +16,17 @@ const (
MaxHeadersAllowed = 2000
)
// ErrTooManyHeaders is an error returned when too many headers were received.
var ErrTooManyHeaders = errors.Errorf("too many headers were received (max: %d)", MaxHeadersAllowed)
// DecodeBinary implements Serializable interface.
func (p *Headers) DecodeBinary(br *io.BinReader) {
lenHeaders := br.ReadVarUint()
var limitExceeded bool
// C# node does it silently
if lenHeaders > MaxHeadersAllowed {
if limitExceeded = lenHeaders > MaxHeadersAllowed; limitExceeded {
lenHeaders = MaxHeadersAllowed
}
@ -31,6 +37,10 @@ func (p *Headers) DecodeBinary(br *io.BinReader) {
header.DecodeBinary(br)
p.Hdrs[i] = header
}
if br.Err == nil && limitExceeded {
br.Err = ErrTooManyHeaders
}
}
// EncodeBinary implements Serializable interface.

View file

@ -11,36 +11,39 @@ import (
)
func TestHeadersEncodeDecode(t *testing.T) {
headers := &Headers{[]*core.Header{
{
BlockBase: core.BlockBase{
Version: 0,
Index: 1,
Script: transaction.Witness{
InvocationScript: []byte{0x0},
VerificationScript: []byte{0x1},
},
}},
{
BlockBase: core.BlockBase{
Version: 0,
Index: 2,
Script: transaction.Witness{
InvocationScript: []byte{0x0},
VerificationScript: []byte{0x1},
},
}},
{
BlockBase: core.BlockBase{
Version: 0,
Index: 3,
Script: transaction.Witness{
InvocationScript: []byte{0x0},
VerificationScript: []byte{0x1},
},
}},
}}
t.Run("normal case", func(t *testing.T) {
headers := newTestHeaders(3)
testHeadersEncodeDecode(t, headers, 3, false)
})
t.Run("more than max", func(t *testing.T) {
const sent = MaxHeadersAllowed + 1
headers := newTestHeaders(sent)
testHeadersEncodeDecode(t, headers, MaxHeadersAllowed, true)
})
}
func newTestHeaders(n int) *Headers {
headers := &Headers{Hdrs: make([]*core.Header, n)}
for i := range headers.Hdrs {
headers.Hdrs[i] = &core.Header{
BlockBase: core.BlockBase{
Index: uint32(i + 1),
Script: transaction.Witness{
InvocationScript: []byte{0x0},
VerificationScript: []byte{0x1},
},
},
}
}
return headers
}
func testHeadersEncodeDecode(t *testing.T, headers *Headers, expected int, limit bool) {
buf := io.NewBufBinWriter()
headers.EncodeBinary(buf.BinWriter)
assert.Nil(t, buf.Err)
@ -49,9 +52,16 @@ func TestHeadersEncodeDecode(t *testing.T) {
r := io.NewBinReaderFromBuf(b)
headersDecode := &Headers{}
headersDecode.DecodeBinary(r)
assert.Nil(t, r.Err)
for i := 0; i < len(headers.Hdrs); i++ {
var err error
if limit {
err = ErrTooManyHeaders
}
assert.Equal(t, err, r.Err)
assert.Equal(t, expected, len(headersDecode.Hdrs))
for i := 0; i < len(headersDecode.Hdrs); i++ {
assert.Equal(t, headers.Hdrs[i].Version, headersDecode.Hdrs[i].Version)
assert.Equal(t, headers.Hdrs[i].Index, headersDecode.Hdrs[i].Index)
assert.Equal(t, headers.Hdrs[i].Script, headersDecode.Hdrs[i].Script)

View file

@ -6,6 +6,7 @@ import (
"time"
"github.com/CityOfZion/neo-go/pkg/io"
"github.com/CityOfZion/neo-go/pkg/network/payload"
"go.uber.org/zap"
)
@ -87,7 +88,12 @@ func (t *TCPTransport) handleConn(conn net.Conn) {
r := io.NewBinReaderFromIO(p.conn)
for {
msg := &Message{}
if err = msg.Decode(r); err != nil {
err := msg.Decode(r)
if err == payload.ErrTooManyHeaders {
t.log.Warn("not all headers were processed")
r.Err = nil
} else if err != nil {
break
}
if err = t.server.handleMessage(p, msg); err != nil {