From 387c411da067f26584ea483a52d7e66dede1e7e8 Mon Sep 17 00:00:00 2001 From: Anna Shaleva Date: Wed, 22 Nov 2023 13:52:41 +0300 Subject: [PATCH] vm: add default limit to SI serialization context Follow the notion of https://github.com/neo-project/neo/pull/2948. Signed-off-by: Anna Shaleva --- pkg/vm/stackitem/json.go | 13 ++-- pkg/vm/stackitem/serialization.go | 36 +++++++++-- pkg/vm/stackitem/serialization_test.go | 84 ++++++++++++++++++++++++-- 3 files changed, 118 insertions(+), 15 deletions(-) diff --git a/pkg/vm/stackitem/json.go b/pkg/vm/stackitem/json.go index a2186ed1f..c18ef002d 100644 --- a/pkg/vm/stackitem/json.go +++ b/pkg/vm/stackitem/json.go @@ -65,9 +65,12 @@ func ToJSON(item Item) ([]byte, error) { } // sliceNoPointer represents a sub-slice of a known slice. -// It doesn't contain any pointer and uses less memory than `[]byte`. +// It doesn't contain any pointer and uses the same amount of memory as `[]byte`, +// but at the same time has additional information about the number of items in +// the stackitem (including the stackitem itself). 
type sliceNoPointer struct { start, end int + itemsCount int } func toJSON(data []byte, seen map[Item]sliceNoPointer, item Item) ([]byte, error) { @@ -105,7 +108,7 @@ func toJSON(data []byte, seen map[Item]sliceNoPointer, item Item) ([]byte, error } } data = append(data, ']') - seen[item] = sliceNoPointer{start, len(data)} + seen[item] = sliceNoPointer{start: start, end: len(data)} case *Map: data = append(data, '{') for i := range it.value { @@ -126,7 +129,7 @@ func toJSON(data []byte, seen map[Item]sliceNoPointer, item Item) ([]byte, error } } data = append(data, '}') - seen[item] = sliceNoPointer{start, len(data)} + seen[item] = sliceNoPointer{start: start, end: len(data)} case *BigInteger: if it.Big().CmpAbs(big.NewInt(MaxAllowedInteger)) == 1 { return nil, fmt.Errorf("%w (MaxAllowedInteger)", ErrInvalidValue) @@ -420,7 +423,7 @@ func toJSONWithTypes(data []byte, item Item, seen map[Item]sliceNoPointer) ([]by data = append(data, '}') if isBuffer { - seen[item] = sliceNoPointer{start, len(data)} + seen[item] = sliceNoPointer{start: start, end: len(data)} } } else { if len(data)+2 > MaxSize { // also take care of '}' @@ -428,7 +431,7 @@ func toJSONWithTypes(data []byte, item Item, seen map[Item]sliceNoPointer) ([]by } data = append(data, ']', '}') - seen[item] = sliceNoPointer{start, len(data)} + seen[item] = sliceNoPointer{start: start, end: len(data)} } return data, nil } diff --git a/pkg/vm/stackitem/serialization.go b/pkg/vm/stackitem/serialization.go index baab7fcd6..35e4d0713 100644 --- a/pkg/vm/stackitem/serialization.go +++ b/pkg/vm/stackitem/serialization.go @@ -14,6 +14,10 @@ import ( // (including itself). const MaxDeserialized = 2048 +// MaxSerialized is the maximum number one serialized item can contain +// (including itself). +const MaxSerialized = MaxDeserialized + // typicalNumOfItems is the number of items covering most serialization needs. 
// It's a hint used for map creation, so it does not limit anything, it's just // a microoptimization to avoid excessive reallocations. Most of the serialized @@ -33,6 +37,7 @@ type SerializationContext struct { uv [9]byte data []byte allowInvalid bool + limit int seen map[Item]sliceNoPointer } @@ -45,10 +50,20 @@ type deserContext struct { // Serialize encodes the given Item into a byte slice. func Serialize(item Item) ([]byte, error) { + return SerializeLimited(item, MaxSerialized) +} + +// SerializeLimited encodes the given Item into a byte slice using custom +// limit to restrict the maximum serialized number of elements. +func SerializeLimited(item Item, limit int) ([]byte, error) { sc := SerializationContext{ allowInvalid: false, + limit: MaxSerialized, seen: make(map[Item]sliceNoPointer, typicalNumOfItems), } + if limit > 0 { + sc.limit = limit + } err := sc.serialize(item) if err != nil { return nil, err @@ -76,6 +91,7 @@ func EncodeBinary(item Item, w *io.BinWriter) { func EncodeBinaryProtected(item Item, w *io.BinWriter) { sc := SerializationContext{ allowInvalid: true, + limit: MaxSerialized, seen: make(map[Item]sliceNoPointer, typicalNumOfItems), } err := sc.serialize(item) @@ -88,28 +104,32 @@ func EncodeBinaryProtected(item Item, w *io.BinWriter) { func (w *SerializationContext) writeArray(item Item, arr []Item, start int) error { w.seen[item] = sliceNoPointer{} + limit := w.limit w.appendVarUint(uint64(len(arr))) for i := range arr { if err := w.serialize(arr[i]); err != nil { return err } } - w.seen[item] = sliceNoPointer{start, len(w.data)} + w.seen[item] = sliceNoPointer{start, len(w.data), limit - w.limit + 1} // number of items including the array itself. return nil } // NewSerializationContext returns reusable stack item serialization context. 
func NewSerializationContext() *SerializationContext { return &SerializationContext{ - seen: make(map[Item]sliceNoPointer, typicalNumOfItems), + limit: MaxSerialized, + seen: make(map[Item]sliceNoPointer, typicalNumOfItems), } } // Serialize returns flat slice of bytes with the given item. The process can be protected // from bad elements if appropriate flag is given (otherwise an error is returned on // encountering any of them). The buffer returned is only valid until the call to Serialize. +// The number of serialized items is restricted with MaxSerialized. func (w *SerializationContext) Serialize(item Item, protected bool) ([]byte, error) { w.allowInvalid = protected + w.limit = MaxSerialized if w.data != nil { w.data = w.data[:0] } @@ -135,10 +155,17 @@ func (w *SerializationContext) serialize(item Item) error { if len(w.data)+v.end-v.start > MaxSize { return ErrTooBig } + w.limit -= v.itemsCount + if w.limit < 0 { + return errTooBigElements + } w.data = append(w.data, w.data[v.start:v.end]...) return nil } - + w.limit-- + if w.limit < 0 { + return errTooBigElements + } start := len(w.data) switch t := item.(type) { case *ByteArray: @@ -188,6 +215,7 @@ func (w *SerializationContext) serialize(item Item) error { } case *Map: w.seen[item] = sliceNoPointer{} + limit := w.limit elems := t.value w.data = append(w.data, byte(MapT)) @@ -200,7 +228,7 @@ func (w *SerializationContext) serialize(item Item) error { return err } } - w.seen[item] = sliceNoPointer{start, len(w.data)} + w.seen[item] = sliceNoPointer{start, len(w.data), limit - w.limit + 1} // number of items including Map itself. 
case Null: w.data = append(w.data, byte(AnyT)) case nil: diff --git a/pkg/vm/stackitem/serialization_test.go b/pkg/vm/stackitem/serialization_test.go index 4722df88b..4c9fc3a92 100644 --- a/pkg/vm/stackitem/serialization_test.go +++ b/pkg/vm/stackitem/serialization_test.go @@ -1,6 +1,7 @@ package stackitem import ( + "strconv" "testing" "github.com/nspcc-dev/neo-go/pkg/io" @@ -23,7 +24,19 @@ func TestSerializationMaxErr(t *testing.T) { } func testSerialize(t *testing.T, expectedErr error, item Item) { - data, err := Serialize(item) + testSerializeLimited(t, expectedErr, item, -1) +} + +func testSerializeLimited(t *testing.T, expectedErr error, item Item, limit int) { + var ( + data []byte + err error + ) + if limit > 0 { + data, err = SerializeLimited(item, limit) + } else { + data, err = Serialize(item) + } if expectedErr != nil { require.ErrorIs(t, err, expectedErr) return @@ -58,7 +71,9 @@ func TestSerialize(t *testing.T) { testSerialize(t, nil, newItem(items)) items = append(items, zeroByteArray) - data, err := Serialize(newItem(items)) + _, err := Serialize(newItem(items)) + require.ErrorIs(t, err, errTooBigElements) + data, err := SerializeLimited(newItem(items), MaxSerialized+1) // a tiny hack to check deserialization error. require.NoError(t, err) _, err = Deserialize(data) require.ErrorIs(t, err, ErrTooBig) @@ -165,13 +180,70 @@ func TestSerialize(t *testing.T) { for i := 0; i <= MaxDeserialized; i++ { m.Add(Make(i), zeroByteArray) } - data, err := Serialize(m) + _, err := Serialize(m) + require.ErrorIs(t, err, errTooBigElements) + data, err := SerializeLimited(m, (MaxSerialized+1)*2+1) // a tiny hack to check deserialization error. 
require.NoError(t, err) _, err = Deserialize(data) require.ErrorIs(t, err, ErrTooBig) }) } +func TestSerializeLimited(t *testing.T) { + const customLimit = 10 + + smallArray := make([]Item, customLimit-1) + for i := range smallArray { + smallArray[i] = NewBool(true) + } + bigArray := make([]Item, customLimit) + for i := range bigArray { + bigArray[i] = NewBool(true) + } + t.Run("array", func(t *testing.T) { + testSerializeLimited(t, nil, NewArray(smallArray), customLimit) + testSerializeLimited(t, errTooBigElements, NewArray(bigArray), customLimit) + }) + t.Run("struct", func(t *testing.T) { + testSerializeLimited(t, nil, NewStruct(smallArray), customLimit) + testSerializeLimited(t, errTooBigElements, NewStruct(bigArray), customLimit) + }) + t.Run("map", func(t *testing.T) { + smallMap := make([]MapElement, (customLimit-1)/2) + for i := range smallMap { + smallMap[i] = MapElement{ + Key: NewByteArray([]byte(strconv.Itoa(i))), + Value: NewBool(true), + } + } + bigMap := make([]MapElement, customLimit/2) + for i := range bigMap { + bigMap[i] = MapElement{ + Key: NewByteArray([]byte("key")), + Value: NewBool(true), + } + } + testSerializeLimited(t, nil, NewMapWithValue(smallMap), customLimit) + testSerializeLimited(t, errTooBigElements, NewMapWithValue(bigMap), customLimit) + }) + t.Run("seen", func(t *testing.T) { + t.Run("OK", func(t *testing.T) { + tinyArray := NewArray(make([]Item, (customLimit-3)/2)) // 1 for outer array, 1+1 for two inner arrays and the rest are for arrays' elements. + for i := range tinyArray.value { + tinyArray.value[i] = NewBool(true) + } + testSerializeLimited(t, nil, NewArray([]Item{tinyArray, tinyArray}), customLimit) + }) + t.Run("big", func(t *testing.T) { + tinyArray := NewArray(make([]Item, (customLimit-2)/2)) // should break on the second array serialisation. 
+ for i := range tinyArray.value { + tinyArray.value[i] = NewBool(true) + } + testSerializeLimited(t, errTooBigElements, NewArray([]Item{tinyArray, tinyArray}), customLimit) + }) + }) +} + func TestEmptyDeserialization(t *testing.T) { empty := []byte{} _, err := Deserialize(empty) @@ -202,7 +274,7 @@ func TestDeserializeTooManyElements(t *testing.T) { require.NoError(t, err) item = Make([]Item{item}) - data, err = Serialize(item) + data, err = SerializeLimited(item, MaxSerialized+1) // tiny hack to avoid serialization error. require.NoError(t, err) _, err = Deserialize(data) require.ErrorIs(t, err, ErrTooBig) @@ -214,14 +286,14 @@ func TestDeserializeLimited(t *testing.T) { for i := 0; i < customLimit-1; i++ { // 1 for zero inner element. item = Make([]Item{item}) } - data, err := Serialize(item) + data, err := SerializeLimited(item, customLimit) // tiny hack to avoid serialization error. require.NoError(t, err) actual, err := DeserializeLimited(data, customLimit) require.NoError(t, err) require.Equal(t, item, actual) item = Make([]Item{item}) - data, err = Serialize(item) + data, err = SerializeLimited(item, customLimit+1) // tiny hack to avoid serialization error. require.NoError(t, err) _, err = DeserializeLimited(data, customLimit) require.Error(t, err)