payload: treat zero-length headers as error

See neo-project/neo#2259.
This commit is contained in:
Roman Khimov 2021-02-06 00:06:01 +03:00
parent aad2b3adad
commit eecd71abeb
2 changed files with 19 additions and 9 deletions

View file

@ -1,6 +1,7 @@
package payload
import (
"errors"
"fmt"
"github.com/nspcc-dev/neo-go/pkg/config/netmode"
@ -24,10 +25,18 @@ const (
// ErrTooManyHeaders is an error returned when too many headers were received.
var ErrTooManyHeaders = fmt.Errorf("too many headers were received (max: %d)", MaxHeadersAllowed)
// ErrNoHeaders is returned for zero-elements Headers payload which is considered to be invalid.
var ErrNoHeaders = errors.New("no headers (zero length array)")
// DecodeBinary implements Serializable interface.
func (p *Headers) DecodeBinary(br *io.BinReader) {
lenHeaders := br.ReadVarUint()
if br.Err == nil && lenHeaders == 0 {
br.Err = ErrNoHeaders
return
}
var limitExceeded bool
// C# node does it silently

View file

@ -13,14 +13,20 @@ func TestHeadersEncodeDecode(t *testing.T) {
t.Run("normal case", func(t *testing.T) {
headers := newTestHeaders(3)
testHeadersEncodeDecode(t, headers, 3, false)
testHeadersEncodeDecode(t, headers, 3, nil)
})
t.Run("more than max", func(t *testing.T) {
const sent = MaxHeadersAllowed + 1
headers := newTestHeaders(sent)
testHeadersEncodeDecode(t, headers, MaxHeadersAllowed, true)
testHeadersEncodeDecode(t, headers, MaxHeadersAllowed, ErrTooManyHeaders)
})
t.Run("zero", func(t *testing.T) {
headers := newTestHeaders(0)
testHeadersEncodeDecode(t, headers, 0, ErrNoHeaders)
})
}
@ -42,19 +48,14 @@ func newTestHeaders(n int) *Headers {
return headers
}
func testHeadersEncodeDecode(t *testing.T, headers *Headers, expected int, limit bool) {
func testHeadersEncodeDecode(t *testing.T, headers *Headers, expected int, retErr error) {
data, err := testserdes.EncodeBinary(headers)
assert.Nil(t, err)
headersDecode := &Headers{}
rErr := testserdes.DecodeBinary(data, headersDecode)
err = nil
if limit {
err = ErrTooManyHeaders
}
assert.Equal(t, err, rErr)
assert.Equal(t, retErr, rErr)
assert.Equal(t, expected, len(headersDecode.Hdrs))
for i := 0; i < len(headersDecode.Hdrs); i++ {