io: restrict ReadArray max array size
This commit is contained in:
parent
2679d3fa35
commit
f65545023d
2 changed files with 29 additions and 2 deletions
|
@ -3,10 +3,15 @@ package io
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"reflect"
|
"reflect"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// maxArraySize is a maximums size of an array which can be decoded.
|
||||||
|
// It is taken from https://github.com/neo-project/neo/blob/master/neo/IO/Helper.cs#L130
|
||||||
|
const maxArraySize = 0x1000000
|
||||||
|
|
||||||
// BinReader is a convenient wrapper around a io.Reader and err object.
|
// BinReader is a convenient wrapper around a io.Reader and err object.
|
||||||
// Used to simplify error handling when reading into a struct with many fields.
|
// Used to simplify error handling when reading into a struct with many fields.
|
||||||
type BinReader struct {
|
type BinReader struct {
|
||||||
|
@ -36,7 +41,7 @@ func (r *BinReader) ReadLE(v interface{}) {
|
||||||
|
|
||||||
// ReadArray reads array into value which must be
|
// ReadArray reads array into value which must be
|
||||||
// a pointer to a slice.
|
// a pointer to a slice.
|
||||||
func (r *BinReader) ReadArray(t interface{}) {
|
func (r *BinReader) ReadArray(t interface{}, maxSize ...int) {
|
||||||
value := reflect.ValueOf(t)
|
value := reflect.ValueOf(t)
|
||||||
if value.Kind() != reflect.Ptr || value.Elem().Kind() != reflect.Slice {
|
if value.Kind() != reflect.Ptr || value.Elem().Kind() != reflect.Slice {
|
||||||
panic(value.Type().String() + " is not a pointer to a slice")
|
panic(value.Type().String() + " is not a pointer to a slice")
|
||||||
|
@ -55,7 +60,18 @@ func (r *BinReader) ReadArray(t interface{}) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
l := int(r.ReadVarUint())
|
ms := maxArraySize
|
||||||
|
if len(maxSize) != 0 {
|
||||||
|
ms = maxSize[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
lu := r.ReadVarUint()
|
||||||
|
if lu > uint64(ms) {
|
||||||
|
r.Err = fmt.Errorf("array is too big (%d)", lu)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
l := int(lu)
|
||||||
arr := reflect.MakeSlice(sliceType, l, l)
|
arr := reflect.MakeSlice(sliceType, l, l)
|
||||||
|
|
||||||
for i := 0; i < l; i++ {
|
for i := 0; i < l; i++ {
|
||||||
|
|
|
@ -266,6 +266,17 @@ func TestBinReader_ReadArray(t *testing.T) {
|
||||||
require.NoError(t, r.Err)
|
require.NoError(t, r.Err)
|
||||||
require.Equal(t, elems, arrVal)
|
require.Equal(t, elems, arrVal)
|
||||||
|
|
||||||
|
r = NewBinReaderFromBuf(data)
|
||||||
|
arrVal = []testSerializable{}
|
||||||
|
r.ReadArray(&arrVal, 3)
|
||||||
|
require.NoError(t, r.Err)
|
||||||
|
require.Equal(t, elems, arrVal)
|
||||||
|
|
||||||
|
r = NewBinReaderFromBuf(data)
|
||||||
|
arrVal = []testSerializable{}
|
||||||
|
r.ReadArray(&arrVal, 2)
|
||||||
|
require.Error(t, r.Err)
|
||||||
|
|
||||||
r = NewBinReaderFromBuf([]byte{0})
|
r = NewBinReaderFromBuf([]byte{0})
|
||||||
r.ReadArray(&arrVal)
|
r.ReadArray(&arrVal)
|
||||||
require.NoError(t, r.Err)
|
require.NoError(t, r.Err)
|
||||||
|
|
Loading…
Reference in a new issue