diff --git a/cli/smartcontract/manifest.go b/cli/smartcontract/manifest.go index 3964ae1a9..76c7fdc68 100644 --- a/cli/smartcontract/manifest.go +++ b/cli/smartcontract/manifest.go @@ -109,7 +109,7 @@ func readManifest(filename string, hash util.Uint160) (*manifest.Manifest, []byt if err != nil { return nil, nil, err } - if err := m.IsValid(hash); err != nil { + if err := m.IsValid(hash, true); err != nil { return nil, nil, fmt.Errorf("manifest is invalid: %w", err) } return m, manifestBytes, nil diff --git a/pkg/compiler/compiler.go b/pkg/compiler/compiler.go index 37c9077ec..191bfb848 100644 --- a/pkg/compiler/compiler.go +++ b/pkg/compiler/compiler.go @@ -459,7 +459,7 @@ func CreateManifest(di *DebugInfo, o *Options) (*manifest.Manifest, error) { return m, fmt.Errorf("method %s is marked as safe but missing from manifest", name) } } - err = m.IsValid(util.Uint160{}) // Check as much as possible without hash. + err = m.IsValid(util.Uint160{}, true) // Check as much as possible without hash. if err != nil { return m, fmt.Errorf("manifest is invalid: %w", err) } diff --git a/pkg/core/blockchain_neotest_test.go b/pkg/core/blockchain_neotest_test.go index 877610cd7..32d5ec93f 100644 --- a/pkg/core/blockchain_neotest_test.go +++ b/pkg/core/blockchain_neotest_test.go @@ -219,7 +219,6 @@ func TestBlockchain_StartFromExistingDB(t *testing.T) { require.Error(t, err) require.True(t, strings.Contains(err.Error(), "can't init MPT at height"), err) }) - /* See #2801 t.Run("failed native Management initialisation", func(t *testing.T) { ps = newPS(t) @@ -234,9 +233,8 @@ func TestBlockchain_StartFromExistingDB(t *testing.T) { _, _, _, err := chain.NewMultiWithCustomConfigAndStoreNoCheck(t, customConfig, cache) require.Error(t, err) - require.True(t, strings.Contains(err.Error(), "can't init cache for Management native contract"), err) + require.True(t, strings.Contains(err.Error(), "can't init natives cache: failed to initialize cache for ContractManagement"), err) }) - */ t.Run("invalid native contract activation", func(t *testing.T) { ps = newPS(t) diff --git a/pkg/core/native/management.go b/pkg/core/native/management.go index e48609c1c..233513164 100644 --- a/pkg/core/native/management.go +++ b/pkg/core/native/management.go @@ -389,7 +389,7 @@ func (m *Management) Deploy(ic *interop.Context, sender util.Uint160, neff *nef. if err != nil { return nil, err } - err = manif.IsValid(h) + err = manif.IsValid(h, false) // do not check manifest size, the whole state.Contract will be checked later. if err != nil { return nil, fmt.Errorf("invalid manifest: %w", err) } @@ -458,7 +458,7 @@ func (m *Management) Update(ic *interop.Context, hash util.Uint160, neff *nef.Fi if manif.Name != contract.Manifest.Name { return nil, errors.New("contract name can't be changed") } - err = manif.IsValid(contract.Hash) + err = manif.IsValid(contract.Hash, false) // do not check manifest size, the whole state.Contract will be checked later. if err != nil { return nil, fmt.Errorf("invalid manifest: %w", err) } @@ -619,13 +619,19 @@ func (m *Management) InitializeCache(blockHeight uint32, d *dao.Simple) error { nep17: make(map[util.Uint160]struct{}), } - d.Seek(m.ID, storage.SeekRange{Prefix: []byte{PrefixContract}}, func(k, v []byte) bool { + var initErr error + d.Seek(m.ID, storage.SeekRange{Prefix: []byte{PrefixContract}}, func(_, v []byte) bool { var cs = new(state.Contract) - if stackitem.DeserializeConvertible(v, cs) == nil { - updateContractCache(cache, cs) + initErr = stackitem.DeserializeConvertible(v, cs) + if initErr != nil { + return false } + updateContractCache(cache, cs) return true }) + if initErr != nil { + return initErr + } d.SetCache(m.ID, cache) return nil } diff --git a/pkg/core/native/management_test.go b/pkg/core/native/management_test.go index 4184eb593..c8a13b4a5 100644 --- a/pkg/core/native/management_test.go +++ b/pkg/core/native/management_test.go @@ -85,14 +85,12 @@ func TestManagement_Initialize(t *testing.T) { mgmt := newManagement() require.NoError(t, mgmt.InitializeCache(0, d)) }) - /* See #2801 t.Run("invalid contract state", func(t *testing.T) { d := dao.NewSimple(storage.NewMemoryStore(), false) mgmt := newManagement() d.PutStorageItem(mgmt.ID, []byte{PrefixContract}, state.StorageItem{0xFF}) - require.Error(t, mgmt.InitializeCache(d)) + require.Error(t, mgmt.InitializeCache(0, d)) }) - */ } func TestManagement_GetNEP17Contracts(t *testing.T) { diff --git a/pkg/smartcontract/manifest/manifest.go b/pkg/smartcontract/manifest/manifest.go index 5c002381e..3f5ac8417 100644 --- a/pkg/smartcontract/manifest/manifest.go +++ b/pkg/smartcontract/manifest/manifest.go @@ -84,7 +84,7 @@ func (m *Manifest) CanCall(hash util.Uint160, toCall *Manifest, method string) b // IsValid checks manifest internal consistency and correctness, one of the // checks is for group signature correctness, contract hash is passed for it. // If hash is empty, then hash-related checks are omitted. -func (m *Manifest) IsValid(hash util.Uint160) error { +func (m *Manifest) IsValid(hash util.Uint160, checkSize bool) error { var err error if m.Name == "" { @@ -118,7 +118,22 @@ func (m *Manifest) IsValid(hash util.Uint160) error { return errors.New("duplicate trusted contracts") } } - return Permissions(m.Permissions).AreValid() + err = Permissions(m.Permissions).AreValid() + if err != nil { + return err + } + if !checkSize { + return nil + } + si, err := m.ToStackItem() + if err != nil { + return fmt.Errorf("failed to check manifest serialisation: %w", err) + } + _, err = stackitem.Serialize(si) + if err != nil { + return fmt.Errorf("manifest is not serializable: %w", err) + } + return nil } // IsStandardSupported denotes whether the specified standard is supported by the contract. diff --git a/pkg/smartcontract/manifest/manifest_test.go b/pkg/smartcontract/manifest/manifest_test.go index 2b9b20bae..7fd940f81 100644 --- a/pkg/smartcontract/manifest/manifest_test.go +++ b/pkg/smartcontract/manifest/manifest_test.go @@ -2,6 +2,7 @@ package manifest import ( "encoding/json" + "fmt" "math/big" "testing" @@ -124,13 +125,13 @@ func TestIsValid(t *testing.T) { m := &Manifest{} t.Run("invalid, no name", func(t *testing.T) { - require.Error(t, m.IsValid(contractHash)) + require.Error(t, m.IsValid(contractHash, true)) }) m = NewManifest("Test") t.Run("invalid, no ABI methods", func(t *testing.T) { - require.Error(t, m.IsValid(contractHash)) + require.Error(t, m.IsValid(contractHash, true)) }) m.ABI.Methods = append(m.ABI.Methods, Method{ @@ -140,7 +141,7 @@ func TestIsValid(t *testing.T) { }) t.Run("valid, no groups/events", func(t *testing.T) { - require.NoError(t, m.IsValid(contractHash)) + require.NoError(t, m.IsValid(contractHash, true)) }) m.ABI.Events = append(m.ABI.Events, Event{ @@ -149,7 +150,7 @@ func TestIsValid(t *testing.T) { }) t.Run("valid, with events", func(t *testing.T) { - require.NoError(t, m.IsValid(contractHash)) + require.NoError(t, m.IsValid(contractHash, true)) }) m.ABI.Events = append(m.ABI.Events, Event{ @@ -161,52 +162,52 @@ func TestIsValid(t *testing.T) { }) t.Run("invalid, bad event", func(t *testing.T) { - require.Error(t, m.IsValid(contractHash)) + require.Error(t, m.IsValid(contractHash, true)) }) m.ABI.Events = m.ABI.Events[:1] m.Permissions = append(m.Permissions, *NewPermission(PermissionHash, util.Uint160{1, 2, 3})) t.Run("valid, with permissions", func(t *testing.T) { - require.NoError(t, m.IsValid(contractHash)) + require.NoError(t, m.IsValid(contractHash, true)) }) m.Permissions = append(m.Permissions, *NewPermission(PermissionHash, util.Uint160{1, 2, 3})) t.Run("invalid, with permissions", func(t *testing.T) { - require.Error(t, m.IsValid(contractHash)) + require.Error(t, m.IsValid(contractHash, true)) }) m.Permissions = m.Permissions[:1] m.SupportedStandards = append(m.SupportedStandards, "NEP-17") t.Run("valid, with standards", func(t *testing.T) { - require.NoError(t, m.IsValid(contractHash)) + require.NoError(t, m.IsValid(contractHash, true)) }) m.SupportedStandards = append(m.SupportedStandards, "") t.Run("invalid, with nameless standard", func(t *testing.T) { - require.Error(t, m.IsValid(contractHash)) + require.Error(t, m.IsValid(contractHash, true)) }) m.SupportedStandards = m.SupportedStandards[:1] m.SupportedStandards = append(m.SupportedStandards, "NEP-17") t.Run("invalid, with duplicate standards", func(t *testing.T) { - require.Error(t, m.IsValid(contractHash)) + require.Error(t, m.IsValid(contractHash, true)) }) m.SupportedStandards = m.SupportedStandards[:1] d := PermissionDesc{Type: PermissionHash, Value: util.Uint160{1, 2, 3}} m.Trusts.Add(d) t.Run("valid, with trust", func(t *testing.T) { - require.NoError(t, m.IsValid(contractHash)) + require.NoError(t, m.IsValid(contractHash, true)) }) m.Trusts.Add(PermissionDesc{Type: PermissionHash, Value: util.Uint160{3, 2, 1}}) t.Run("valid, with trusts", func(t *testing.T) { - require.NoError(t, m.IsValid(contractHash)) + require.NoError(t, m.IsValid(contractHash, true)) }) m.Trusts.Add(d) t.Run("invalid, with trusts", func(t *testing.T) { - require.Error(t, m.IsValid(contractHash)) + require.Error(t, m.IsValid(contractHash, true)) }) m.Trusts.Restrict() @@ -224,11 +225,11 @@ func TestIsValid(t *testing.T) { } t.Run("valid", func(t *testing.T) { - require.NoError(t, m.IsValid(contractHash)) + require.NoError(t, m.IsValid(contractHash, true)) }) t.Run("invalid, wrong contract hash", func(t *testing.T) { - require.Error(t, m.IsValid(util.Uint160{4, 5, 6})) + require.Error(t, m.IsValid(util.Uint160{4, 5, 6}, true)) }) t.Run("invalid, wrong group signature", func(t *testing.T) { @@ -240,9 +241,21 @@ func TestIsValid(t *testing.T) { // of the contract hash. Signature: pk.Sign([]byte{1, 2, 3}), }) - require.Error(t, m.IsValid(contractHash)) + require.Error(t, m.IsValid(contractHash, true)) }) }) + m.Groups = m.Groups[:0] + + t.Run("invalid, unserializable", func(t *testing.T) { + for i := 0; i < stackitem.MaxSerialized; i++ { + m.ABI.Events = append(m.ABI.Events, Event{ + Name: fmt.Sprintf("Event%d", i), + Parameters: []Parameter{}, + }) + } + err := m.IsValid(contractHash, true) + require.ErrorIs(t, err, stackitem.ErrTooBig) + }) } func TestManifestToStackItem(t *testing.T) { diff --git a/pkg/vm/stackitem/json.go b/pkg/vm/stackitem/json.go index a2186ed1f..c18ef002d 100644 --- a/pkg/vm/stackitem/json.go +++ b/pkg/vm/stackitem/json.go @@ -65,9 +65,12 @@ func ToJSON(item Item) ([]byte, error) { } // sliceNoPointer represents a sub-slice of a known slice. -// It doesn't contain any pointer and uses less memory than `[]byte`. +// It doesn't contain any pointer and uses the same amount of memory as `[]byte`, +// but at the same type has additional information about the number of items in +// the stackitem (including the stackitem itself). type sliceNoPointer struct { start, end int + itemsCount int } func toJSON(data []byte, seen map[Item]sliceNoPointer, item Item) ([]byte, error) { @@ -105,7 +108,7 @@ func toJSON(data []byte, seen map[Item]sliceNoPointer, item Item) ([]byte, error } } data = append(data, ']') - seen[item] = sliceNoPointer{start, len(data)} + seen[item] = sliceNoPointer{start: start, end: len(data)} case *Map: data = append(data, '{') for i := range it.value { @@ -126,7 +129,7 @@ func toJSON(data []byte, seen map[Item]sliceNoPointer, item Item) ([]byte, error } } data = append(data, '}') - seen[item] = sliceNoPointer{start, len(data)} + seen[item] = sliceNoPointer{start: start, end: len(data)} case *BigInteger: if it.Big().CmpAbs(big.NewInt(MaxAllowedInteger)) == 1 { return nil, fmt.Errorf("%w (MaxAllowedInteger)", ErrInvalidValue) @@ -420,7 +423,7 @@ func toJSONWithTypes(data []byte, item Item, seen map[Item]sliceNoPointer) ([]by data = append(data, '}') if isBuffer { - seen[item] = sliceNoPointer{start, len(data)} + seen[item] = sliceNoPointer{start: start, end: len(data)} } } else { if len(data)+2 > MaxSize { // also take care of '}' @@ -428,7 +431,7 @@ func toJSONWithTypes(data []byte, item Item, seen map[Item]sliceNoPointer) ([]by } data = append(data, ']', '}') - seen[item] = sliceNoPointer{start, len(data)} + seen[item] = sliceNoPointer{start: start, end: len(data)} } return data, nil } diff --git a/pkg/vm/stackitem/serialization.go b/pkg/vm/stackitem/serialization.go index baab7fcd6..35e4d0713 100644 --- a/pkg/vm/stackitem/serialization.go +++ b/pkg/vm/stackitem/serialization.go @@ -14,6 +14,10 @@ import ( // (including itself). const MaxDeserialized = 2048 +// MaxSerialized is the maximum number one serialized item can contain +// (including itself). +const MaxSerialized = MaxDeserialized + // typicalNumOfItems is the number of items covering most serialization needs. // It's a hint used for map creation, so it does not limit anything, it's just // a microoptimization to avoid excessive reallocations. Most of the serialized @@ -33,6 +37,7 @@ type SerializationContext struct { uv [9]byte data []byte allowInvalid bool + limit int seen map[Item]sliceNoPointer } @@ -45,10 +50,20 @@ type deserContext struct { // Serialize encodes the given Item into a byte slice. func Serialize(item Item) ([]byte, error) { + return SerializeLimited(item, MaxSerialized) +} + +// SerializeLimited encodes the given Item into a byte slice using custom +// limit to restrict the maximum serialized number of elements. +func SerializeLimited(item Item, limit int) ([]byte, error) { sc := SerializationContext{ allowInvalid: false, + limit: MaxSerialized, seen: make(map[Item]sliceNoPointer, typicalNumOfItems), } + if limit > 0 { + sc.limit = limit + } err := sc.serialize(item) if err != nil { return nil, err @@ -76,6 +91,7 @@ func EncodeBinary(item Item, w *io.BinWriter) { func EncodeBinaryProtected(item Item, w *io.BinWriter) { sc := SerializationContext{ allowInvalid: true, + limit: MaxSerialized, seen: make(map[Item]sliceNoPointer, typicalNumOfItems), } err := sc.serialize(item) @@ -88,28 +104,32 @@ func EncodeBinaryProtected(item Item, w *io.BinWriter) { func (w *SerializationContext) writeArray(item Item, arr []Item, start int) error { w.seen[item] = sliceNoPointer{} + limit := w.limit w.appendVarUint(uint64(len(arr))) for i := range arr { if err := w.serialize(arr[i]); err != nil { return err } } - w.seen[item] = sliceNoPointer{start, len(w.data)} + w.seen[item] = sliceNoPointer{start, len(w.data), limit - w.limit + 1} // number of items including the array itself. return nil } // NewSerializationContext returns reusable stack item serialization context. func NewSerializationContext() *SerializationContext { return &SerializationContext{ - seen: make(map[Item]sliceNoPointer, typicalNumOfItems), + limit: MaxSerialized, + seen: make(map[Item]sliceNoPointer, typicalNumOfItems), } } // Serialize returns flat slice of bytes with the given item. The process can be protected // from bad elements if appropriate flag is given (otherwise an error is returned on // encountering any of them). The buffer returned is only valid until the call to Serialize. +// The number of serialized items is restricted with MaxSerialized. func (w *SerializationContext) Serialize(item Item, protected bool) ([]byte, error) { w.allowInvalid = protected + w.limit = MaxSerialized if w.data != nil { w.data = w.data[:0] } @@ -135,10 +155,17 @@ func (w *SerializationContext) serialize(item Item) error { if len(w.data)+v.end-v.start > MaxSize { return ErrTooBig } + w.limit -= v.itemsCount + if w.limit < 0 { + return errTooBigElements + } w.data = append(w.data, w.data[v.start:v.end]...) return nil } - + w.limit-- + if w.limit < 0 { + return errTooBigElements + } start := len(w.data) switch t := item.(type) { case *ByteArray: @@ -188,6 +215,7 @@ func (w *SerializationContext) serialize(item Item) error { } case *Map: w.seen[item] = sliceNoPointer{} + limit := w.limit elems := t.value w.data = append(w.data, byte(MapT)) @@ -200,7 +228,7 @@ func (w *SerializationContext) serialize(item Item) error { return err } } - w.seen[item] = sliceNoPointer{start, len(w.data)} + w.seen[item] = sliceNoPointer{start, len(w.data), limit - w.limit + 1} // number of items including Map itself. case Null: w.data = append(w.data, byte(AnyT)) case nil: diff --git a/pkg/vm/stackitem/serialization_test.go b/pkg/vm/stackitem/serialization_test.go index 4722df88b..4c9fc3a92 100644 --- a/pkg/vm/stackitem/serialization_test.go +++ b/pkg/vm/stackitem/serialization_test.go @@ -1,6 +1,7 @@ package stackitem import ( + "strconv" "testing" "github.com/nspcc-dev/neo-go/pkg/io" @@ -23,7 +24,19 @@ func TestSerializationMaxErr(t *testing.T) { } func testSerialize(t *testing.T, expectedErr error, item Item) { - data, err := Serialize(item) + testSerializeLimited(t, expectedErr, item, -1) +} + +func testSerializeLimited(t *testing.T, expectedErr error, item Item, limit int) { + var ( + data []byte + err error + ) + if limit > 0 { + data, err = SerializeLimited(item, limit) + } else { + data, err = Serialize(item) + } if expectedErr != nil { require.ErrorIs(t, err, expectedErr) return @@ -58,7 +71,9 @@ func TestSerialize(t *testing.T) { testSerialize(t, nil, newItem(items)) items = append(items, zeroByteArray) - data, err := Serialize(newItem(items)) + _, err := Serialize(newItem(items)) + require.ErrorIs(t, err, errTooBigElements) + data, err := SerializeLimited(newItem(items), MaxSerialized+1) // a tiny hack to check deserialization error. require.NoError(t, err) _, err = Deserialize(data) require.ErrorIs(t, err, ErrTooBig) @@ -165,13 +180,70 @@ func TestSerialize(t *testing.T) { for i := 0; i <= MaxDeserialized; i++ { m.Add(Make(i), zeroByteArray) } - data, err := Serialize(m) + _, err := Serialize(m) + require.ErrorIs(t, err, errTooBigElements) + data, err := SerializeLimited(m, (MaxSerialized+1)*2+1) // a tiny hack to check deserialization error. require.NoError(t, err) _, err = Deserialize(data) require.ErrorIs(t, err, ErrTooBig) }) } +func TestSerializeLimited(t *testing.T) { + const customLimit = 10 + + smallArray := make([]Item, customLimit-1) + for i := range smallArray { + smallArray[i] = NewBool(true) + } + bigArray := make([]Item, customLimit) + for i := range bigArray { + bigArray[i] = NewBool(true) + } + t.Run("array", func(t *testing.T) { + testSerializeLimited(t, nil, NewArray(smallArray), customLimit) + testSerializeLimited(t, errTooBigElements, NewArray(bigArray), customLimit) + }) + t.Run("struct", func(t *testing.T) { + testSerializeLimited(t, nil, NewStruct(smallArray), customLimit) + testSerializeLimited(t, errTooBigElements, NewStruct(bigArray), customLimit) + }) + t.Run("map", func(t *testing.T) { + smallMap := make([]MapElement, (customLimit-1)/2) + for i := range smallMap { + smallMap[i] = MapElement{ + Key: NewByteArray([]byte(strconv.Itoa(i))), + Value: NewBool(true), + } + } + bigMap := make([]MapElement, customLimit/2) + for i := range bigMap { + bigMap[i] = MapElement{ + Key: NewByteArray([]byte("key")), + Value: NewBool(true), + } + } + testSerializeLimited(t, nil, NewMapWithValue(smallMap), customLimit) + testSerializeLimited(t, errTooBigElements, NewMapWithValue(bigMap), customLimit) + }) + t.Run("seen", func(t *testing.T) { + t.Run("OK", func(t *testing.T) { + tinyArray := NewArray(make([]Item, (customLimit-3)/2)) // 1 for outer array, 1+1 for two inner arrays and the rest are for arrays' elements. + for i := range tinyArray.value { + tinyArray.value[i] = NewBool(true) + } + testSerializeLimited(t, nil, NewArray([]Item{tinyArray, tinyArray}), customLimit) + }) + t.Run("big", func(t *testing.T) { + tinyArray := NewArray(make([]Item, (customLimit-2)/2)) // should break on the second array serialisation. + for i := range tinyArray.value { + tinyArray.value[i] = NewBool(true) + } + testSerializeLimited(t, errTooBigElements, NewArray([]Item{tinyArray, tinyArray}), customLimit) + }) + }) +} + func TestEmptyDeserialization(t *testing.T) { empty := []byte{} _, err := Deserialize(empty) @@ -202,7 +274,7 @@ func TestDeserializeTooManyElements(t *testing.T) { require.NoError(t, err) item = Make([]Item{item}) - data, err = Serialize(item) + data, err = SerializeLimited(item, MaxSerialized+1) // tiny hack to avoid serialization error. require.NoError(t, err) _, err = Deserialize(data) require.ErrorIs(t, err, ErrTooBig) @@ -214,14 +286,14 @@ func TestDeserializeLimited(t *testing.T) { for i := 0; i < customLimit-1; i++ { // 1 for zero inner element. item = Make([]Item{item}) } - data, err := Serialize(item) + data, err := SerializeLimited(item, customLimit) // tiny hack to avoid serialization error. require.NoError(t, err) actual, err := DeserializeLimited(data, customLimit) require.NoError(t, err) require.Equal(t, item, actual) item = Make([]Item{item}) - data, err = Serialize(item) + data, err = SerializeLimited(item, customLimit+1) // tiny hack to avoid serialization error. require.NoError(t, err) _, err = DeserializeLimited(data, customLimit) require.Error(t, err)