parent
aad2b3adad
commit
eecd71abeb
2 changed files with 19 additions and 9 deletions
|
@ -1,6 +1,7 @@
|
||||||
package payload
|
package payload
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/nspcc-dev/neo-go/pkg/config/netmode"
|
"github.com/nspcc-dev/neo-go/pkg/config/netmode"
|
||||||
|
@ -24,10 +25,18 @@ const (
|
||||||
// ErrTooManyHeaders is an error returned when too many headers were received.
|
// ErrTooManyHeaders is an error returned when too many headers were received.
|
||||||
var ErrTooManyHeaders = fmt.Errorf("too many headers were received (max: %d)", MaxHeadersAllowed)
|
var ErrTooManyHeaders = fmt.Errorf("too many headers were received (max: %d)", MaxHeadersAllowed)
|
||||||
|
|
||||||
|
// ErrNoHeaders is returned for zero-elements Headers payload which is considered to be invalid.
|
||||||
|
var ErrNoHeaders = errors.New("no headers (zero length array)")
|
||||||
|
|
||||||
// DecodeBinary implements Serializable interface.
|
// DecodeBinary implements Serializable interface.
|
||||||
func (p *Headers) DecodeBinary(br *io.BinReader) {
|
func (p *Headers) DecodeBinary(br *io.BinReader) {
|
||||||
lenHeaders := br.ReadVarUint()
|
lenHeaders := br.ReadVarUint()
|
||||||
|
|
||||||
|
if br.Err == nil && lenHeaders == 0 {
|
||||||
|
br.Err = ErrNoHeaders
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
var limitExceeded bool
|
var limitExceeded bool
|
||||||
|
|
||||||
// C# node does it silently
|
// C# node does it silently
|
||||||
|
|
|
@ -13,14 +13,20 @@ func TestHeadersEncodeDecode(t *testing.T) {
|
||||||
t.Run("normal case", func(t *testing.T) {
|
t.Run("normal case", func(t *testing.T) {
|
||||||
headers := newTestHeaders(3)
|
headers := newTestHeaders(3)
|
||||||
|
|
||||||
testHeadersEncodeDecode(t, headers, 3, false)
|
testHeadersEncodeDecode(t, headers, 3, nil)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("more than max", func(t *testing.T) {
|
t.Run("more than max", func(t *testing.T) {
|
||||||
const sent = MaxHeadersAllowed + 1
|
const sent = MaxHeadersAllowed + 1
|
||||||
headers := newTestHeaders(sent)
|
headers := newTestHeaders(sent)
|
||||||
|
|
||||||
testHeadersEncodeDecode(t, headers, MaxHeadersAllowed, true)
|
testHeadersEncodeDecode(t, headers, MaxHeadersAllowed, ErrTooManyHeaders)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("zero", func(t *testing.T) {
|
||||||
|
headers := newTestHeaders(0)
|
||||||
|
|
||||||
|
testHeadersEncodeDecode(t, headers, 0, ErrNoHeaders)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -42,19 +48,14 @@ func newTestHeaders(n int) *Headers {
|
||||||
return headers
|
return headers
|
||||||
}
|
}
|
||||||
|
|
||||||
func testHeadersEncodeDecode(t *testing.T, headers *Headers, expected int, limit bool) {
|
func testHeadersEncodeDecode(t *testing.T, headers *Headers, expected int, retErr error) {
|
||||||
data, err := testserdes.EncodeBinary(headers)
|
data, err := testserdes.EncodeBinary(headers)
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
headersDecode := &Headers{}
|
headersDecode := &Headers{}
|
||||||
rErr := testserdes.DecodeBinary(data, headersDecode)
|
rErr := testserdes.DecodeBinary(data, headersDecode)
|
||||||
|
|
||||||
err = nil
|
assert.Equal(t, retErr, rErr)
|
||||||
if limit {
|
|
||||||
err = ErrTooManyHeaders
|
|
||||||
}
|
|
||||||
|
|
||||||
assert.Equal(t, err, rErr)
|
|
||||||
assert.Equal(t, expected, len(headersDecode.Hdrs))
|
assert.Equal(t, expected, len(headersDecode.Hdrs))
|
||||||
|
|
||||||
for i := 0; i < len(headersDecode.Hdrs); i++ {
|
for i := 0; i < len(headersDecode.Hdrs); i++ {
|
||||||
|
|
Loading…
Reference in a new issue