From eecd71abeb5d8318f61944d4e3ef274685b5cd5b Mon Sep 17 00:00:00 2001 From: Roman Khimov Date: Sat, 6 Feb 2021 00:06:01 +0300 Subject: [PATCH] payload: treat zero-length headers as error See neo-project/neo#2259. --- pkg/network/payload/headers.go | 9 +++++++++ pkg/network/payload/headers_test.go | 19 ++++++++++--------- 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/pkg/network/payload/headers.go b/pkg/network/payload/headers.go index 61b27053e..7c8c53d30 100644 --- a/pkg/network/payload/headers.go +++ b/pkg/network/payload/headers.go @@ -1,6 +1,7 @@ package payload import ( + "errors" "fmt" "github.com/nspcc-dev/neo-go/pkg/config/netmode" @@ -24,10 +25,18 @@ const ( // ErrTooManyHeaders is an error returned when too many headers were received. var ErrTooManyHeaders = fmt.Errorf("too many headers were received (max: %d)", MaxHeadersAllowed) +// ErrNoHeaders is returned for zero-elements Headers payload which is considered to be invalid. +var ErrNoHeaders = errors.New("no headers (zero length array)") + // DecodeBinary implements Serializable interface. func (p *Headers) DecodeBinary(br *io.BinReader) { lenHeaders := br.ReadVarUint() + if br.Err == nil && lenHeaders == 0 { + br.Err = ErrNoHeaders + return + } + var limitExceeded bool // C# node does it silently diff --git a/pkg/network/payload/headers_test.go b/pkg/network/payload/headers_test.go index 38ef5183b..a8ea27e8b 100644 --- a/pkg/network/payload/headers_test.go +++ b/pkg/network/payload/headers_test.go @@ -13,14 +13,20 @@ func TestHeadersEncodeDecode(t *testing.T) { t.Run("normal case", func(t *testing.T) { headers := newTestHeaders(3) - testHeadersEncodeDecode(t, headers, 3, false) + testHeadersEncodeDecode(t, headers, 3, nil) }) t.Run("more than max", func(t *testing.T) { const sent = MaxHeadersAllowed + 1 headers := newTestHeaders(sent) - testHeadersEncodeDecode(t, headers, MaxHeadersAllowed, true) + testHeadersEncodeDecode(t, headers, MaxHeadersAllowed, ErrTooManyHeaders) + }) + + t.Run("zero", func(t *testing.T) { + headers := newTestHeaders(0) + + testHeadersEncodeDecode(t, headers, 0, ErrNoHeaders) }) } @@ -42,19 +48,14 @@ func newTestHeaders(n int) *Headers { return headers } -func testHeadersEncodeDecode(t *testing.T, headers *Headers, expected int, limit bool) { +func testHeadersEncodeDecode(t *testing.T, headers *Headers, expected int, retErr error) { data, err := testserdes.EncodeBinary(headers) assert.Nil(t, err) headersDecode := &Headers{} rErr := testserdes.DecodeBinary(data, headersDecode) - err = nil - if limit { - err = ErrTooManyHeaders - } - - assert.Equal(t, err, rErr) + assert.Equal(t, retErr, rErr) assert.Equal(t, expected, len(headersDecode.Hdrs)) for i := 0; i < len(headersDecode.Hdrs); i++ {