diff --git a/pkg/vm/stackitem/item.go b/pkg/vm/stackitem/item.go index 1b84f4e51..9d23dccf8 100644 --- a/pkg/vm/stackitem/item.go +++ b/pkg/vm/stackitem/item.go @@ -74,6 +74,7 @@ var ( errTooBigInteger = fmt.Errorf("%w: integer", ErrTooBig) errTooBigKey = fmt.Errorf("%w: map key", ErrTooBig) errTooBigSize = fmt.Errorf("%w: size", ErrTooBig) + errTooBigElements = fmt.Errorf("%w: many elements", ErrTooBig) ) // mkInvConversion creates conversion error with additional metadata (from and diff --git a/pkg/vm/stackitem/serialization.go b/pkg/vm/stackitem/serialization.go index a4271a4cc..c4a8afac8 100644 --- a/pkg/vm/stackitem/serialization.go +++ b/pkg/vm/stackitem/serialization.go @@ -9,6 +9,10 @@ import ( "github.com/nspcc-dev/neo-go/pkg/io" ) +// MaxDeserialized is the maximum number one deserialized item can contain +// (including itself). +const MaxDeserialized = 2048 + // ErrRecursive is returned on attempts to serialize some recursive stack item // (like array including an item with reference to the same array). var ErrRecursive = errors.New("recursive item") @@ -25,6 +29,13 @@ type serContext struct { seen map[Item]sliceNoPointer } +// deserContext is an internal deserialization context. +type deserContext struct { + *io.BinReader + allowInvalid bool + limit int +} + // Serialize encodes given Item into the byte slice. func Serialize(item Item) ([]byte, error) { sc := serContext{ @@ -179,21 +190,36 @@ func Deserialize(data []byte) (Item, error) { // as a function because Item itself is an interface. Caveat: always check // reader's error value before using the returned Item. func DecodeBinary(r *io.BinReader) Item { - return decodeBinary(r, false) + dc := deserContext{ + BinReader: r, + allowInvalid: false, + limit: MaxDeserialized, + } + return dc.decodeBinary() } // DecodeBinaryProtected is similar to DecodeBinary but allows Interop and // Invalid values to be present (making it symmetric to EncodeBinaryProtected). func DecodeBinaryProtected(r *io.BinReader) Item { - return decodeBinary(r, true) + dc := deserContext{ + BinReader: r, + allowInvalid: true, + limit: MaxDeserialized, + } + return dc.decodeBinary() } -func decodeBinary(r *io.BinReader, allowInvalid bool) Item { +func (r *deserContext) decodeBinary() Item { var t = Type(r.ReadB()) if r.Err != nil { return nil } + r.limit-- + if r.limit < 0 { + r.Err = errTooBigElements + return nil + } switch t { case ByteArrayT, BufferT: data := r.ReadVarBytes(MaxSize) @@ -216,7 +242,7 @@ func decodeBinary(r *io.BinReader, allowInvalid bool) Item { } arr := make([]Item, size) for i := 0; i < size; i++ { - arr[i] = decodeBinary(r, allowInvalid) + arr[i] = r.decodeBinary() } if t == ArrayT { @@ -231,8 +257,8 @@ func decodeBinary(r *io.BinReader, allowInvalid bool) Item { } m := NewMap() for i := 0; i < size; i++ { - key := decodeBinary(r, allowInvalid) - value := decodeBinary(r, allowInvalid) + key := r.decodeBinary() + value := r.decodeBinary() if r.Err != nil { break } @@ -242,12 +268,12 @@ func decodeBinary(r *io.BinReader, allowInvalid bool) Item { case AnyT: return Null{} case InteropT: - if allowInvalid { + if r.allowInvalid { return NewInterop(nil) } fallthrough default: - if t == InvalidT && allowInvalid { + if t == InvalidT && r.allowInvalid { return nil } r.Err = fmt.Errorf("%w: %v", ErrInvalidType, t) diff --git a/pkg/vm/stackitem/serialization_test.go b/pkg/vm/stackitem/serialization_test.go index e93a59e21..91da8bd6e 100644 --- a/pkg/vm/stackitem/serialization_test.go +++ b/pkg/vm/stackitem/serialization_test.go @@ -144,7 +144,7 @@ func TestSerialize(t *testing.T) { for i := 0; i < MaxArraySize; i++ { m.Add(Make(i), zeroByteArray) } - testSerialize(t, nil, m) + // testSerialize(t, nil, m) // It contains too many elements already, so ErrTooBig. m.Add(Make(100500), zeroByteArray) data, err := Serialize(m) @@ -154,6 +154,23 @@ func TestSerialize(t *testing.T) { }) } +func TestDeserializeTooManyElements(t *testing.T) { + item := Make(0) + for i := 0; i < MaxDeserialized-1; i++ { // 1 for zero inner element. + item = Make([]Item{item}) + } + data, err := Serialize(item) + require.NoError(t, err) + _, err = Deserialize(data) + require.NoError(t, err) + + item = Make([]Item{item}) + data, err = Serialize(item) + require.NoError(t, err) + _, err = Deserialize(data) + require.True(t, errors.Is(err, ErrTooBig), err) +} + func BenchmarkEncodeBinary(b *testing.B) { arr := getBigArray(15)