From 9b8b77c9ea9163ee0a37bbe494540804e6badbdf Mon Sep 17 00:00:00 2001 From: Evgenii Stratonikov Date: Mon, 30 Dec 2019 15:38:23 +0300 Subject: [PATCH] network: return error if header message is too big Big messages can still be processed but only first 2000 headers will be used. --- pkg/network/message.go | 8 ++-- pkg/network/payload/headers.go | 12 ++++- pkg/network/payload/headers_test.go | 72 ++++++++++++++++------------- pkg/network/tcp_transport.go | 8 +++- 4 files changed, 62 insertions(+), 38 deletions(-) diff --git a/pkg/network/message.go b/pkg/network/message.go index 16802a2ba..087724967 100644 --- a/pkg/network/message.go +++ b/pkg/network/message.go @@ -201,13 +201,11 @@ func (m *Message) decodePayload(br *io.BinReader) error { return fmt.Errorf("can't decode command %s", cmdByteArrayToString(m.Command)) } p.DecodeBinary(r) - if r.Err != nil { - return r.Err + if r.Err == nil || r.Err == payload.ErrTooManyHeaders { + m.Payload = p } - m.Payload = p - - return nil + return r.Err } // Encode encodes a Message to any given BinWriter. diff --git a/pkg/network/payload/headers.go b/pkg/network/payload/headers.go index 33f7b8e53..57160c190 100644 --- a/pkg/network/payload/headers.go +++ b/pkg/network/payload/headers.go @@ -3,6 +3,7 @@ package payload import ( "github.com/CityOfZion/neo-go/pkg/core" "github.com/CityOfZion/neo-go/pkg/io" + "github.com/pkg/errors" ) // Headers payload. @@ -15,12 +16,17 @@ const ( MaxHeadersAllowed = 2000 ) +// ErrTooManyHeaders is an error returned when too many headers were received. +var ErrTooManyHeaders = errors.Errorf("too many headers were received (max: %d)", MaxHeadersAllowed) + // DecodeBinary implements Serializable interface. func (p *Headers) DecodeBinary(br *io.BinReader) { lenHeaders := br.ReadVarUint() + var limitExceeded bool + // C# node does it silently - if lenHeaders > MaxHeadersAllowed { + if limitExceeded = lenHeaders > MaxHeadersAllowed; limitExceeded { lenHeaders = MaxHeadersAllowed } @@ -31,6 +37,10 @@ func (p *Headers) DecodeBinary(br *io.BinReader) { header.DecodeBinary(br) p.Hdrs[i] = header } + + if br.Err == nil && limitExceeded { + br.Err = ErrTooManyHeaders + } } // EncodeBinary implements Serializable interface. diff --git a/pkg/network/payload/headers_test.go b/pkg/network/payload/headers_test.go index 44c2b001f..22762cc0e 100644 --- a/pkg/network/payload/headers_test.go +++ b/pkg/network/payload/headers_test.go @@ -11,36 +11,39 @@ import ( ) func TestHeadersEncodeDecode(t *testing.T) { - headers := &Headers{[]*core.Header{ - { - BlockBase: core.BlockBase{ - Version: 0, - Index: 1, - Script: transaction.Witness{ - InvocationScript: []byte{0x0}, - VerificationScript: []byte{0x1}, - }, - }}, - { - BlockBase: core.BlockBase{ - Version: 0, - Index: 2, - Script: transaction.Witness{ - InvocationScript: []byte{0x0}, - VerificationScript: []byte{0x1}, - }, - }}, - { - BlockBase: core.BlockBase{ - Version: 0, - Index: 3, - Script: transaction.Witness{ - InvocationScript: []byte{0x0}, - VerificationScript: []byte{0x1}, - }, - }}, - }} + t.Run("normal case", func(t *testing.T) { + headers := newTestHeaders(3) + testHeadersEncodeDecode(t, headers, 3, false) + }) + + t.Run("more than max", func(t *testing.T) { + const sent = MaxHeadersAllowed + 1 + headers := newTestHeaders(sent) + + testHeadersEncodeDecode(t, headers, MaxHeadersAllowed, true) + }) +} + +func newTestHeaders(n int) *Headers { + headers := &Headers{Hdrs: make([]*core.Header, n)} + + for i := range headers.Hdrs { + headers.Hdrs[i] = &core.Header{ + BlockBase: core.BlockBase{ + Index: uint32(i + 1), + Script: transaction.Witness{ + InvocationScript: []byte{0x0}, + VerificationScript: []byte{0x1}, + }, + }, + } + } + + return headers +} + +func testHeadersEncodeDecode(t *testing.T, headers *Headers, expected int, limit bool) { buf := io.NewBufBinWriter() headers.EncodeBinary(buf.BinWriter) assert.Nil(t, buf.Err) @@ -49,9 +52,16 @@ func TestHeadersEncodeDecode(t *testing.T) { r := io.NewBinReaderFromBuf(b) headersDecode := &Headers{} headersDecode.DecodeBinary(r) - assert.Nil(t, r.Err) - for i := 0; i < len(headers.Hdrs); i++ { + var err error + if limit { + err = ErrTooManyHeaders + } + + assert.Equal(t, err, r.Err) + assert.Equal(t, expected, len(headersDecode.Hdrs)) + + for i := 0; i < len(headersDecode.Hdrs); i++ { assert.Equal(t, headers.Hdrs[i].Version, headersDecode.Hdrs[i].Version) assert.Equal(t, headers.Hdrs[i].Index, headersDecode.Hdrs[i].Index) assert.Equal(t, headers.Hdrs[i].Script, headersDecode.Hdrs[i].Script) diff --git a/pkg/network/tcp_transport.go b/pkg/network/tcp_transport.go index f41015f53..2055529d6 100644 --- a/pkg/network/tcp_transport.go +++ b/pkg/network/tcp_transport.go @@ -6,6 +6,7 @@ import ( "time" "github.com/CityOfZion/neo-go/pkg/io" + "github.com/CityOfZion/neo-go/pkg/network/payload" "go.uber.org/zap" ) @@ -87,7 +88,12 @@ func (t *TCPTransport) handleConn(conn net.Conn) { r := io.NewBinReaderFromIO(p.conn) for { msg := &Message{} - if err = msg.Decode(r); err != nil { + err := msg.Decode(r) + + if err == payload.ErrTooManyHeaders { + t.log.Warn("not all headers were processed") + r.Err = nil + } else if err != nil { break } if err = t.server.handleMessage(p, msg); err != nil {