mirror of
https://github.com/nspcc-dev/neo-go.git
synced 2024-11-25 03:47:18 +00:00
Merge pull request #3218 from nspcc-dev/serialization-limits
Introduce stackitem serialization limits
This commit is contained in:
commit
25ef2c7f16
10 changed files with 179 additions and 46 deletions
|
@ -109,7 +109,7 @@ func readManifest(filename string, hash util.Uint160) (*manifest.Manifest, []byt
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
if err := m.IsValid(hash); err != nil {
|
if err := m.IsValid(hash, true); err != nil {
|
||||||
return nil, nil, fmt.Errorf("manifest is invalid: %w", err)
|
return nil, nil, fmt.Errorf("manifest is invalid: %w", err)
|
||||||
}
|
}
|
||||||
return m, manifestBytes, nil
|
return m, manifestBytes, nil
|
||||||
|
|
|
@ -459,7 +459,7 @@ func CreateManifest(di *DebugInfo, o *Options) (*manifest.Manifest, error) {
|
||||||
return m, fmt.Errorf("method %s is marked as safe but missing from manifest", name)
|
return m, fmt.Errorf("method %s is marked as safe but missing from manifest", name)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
err = m.IsValid(util.Uint160{}) // Check as much as possible without hash.
|
err = m.IsValid(util.Uint160{}, true) // Check as much as possible without hash.
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return m, fmt.Errorf("manifest is invalid: %w", err)
|
return m, fmt.Errorf("manifest is invalid: %w", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -219,7 +219,6 @@ func TestBlockchain_StartFromExistingDB(t *testing.T) {
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.True(t, strings.Contains(err.Error(), "can't init MPT at height"), err)
|
require.True(t, strings.Contains(err.Error(), "can't init MPT at height"), err)
|
||||||
})
|
})
|
||||||
/* See #2801
|
|
||||||
t.Run("failed native Management initialisation", func(t *testing.T) {
|
t.Run("failed native Management initialisation", func(t *testing.T) {
|
||||||
ps = newPS(t)
|
ps = newPS(t)
|
||||||
|
|
||||||
|
@ -234,9 +233,8 @@ func TestBlockchain_StartFromExistingDB(t *testing.T) {
|
||||||
|
|
||||||
_, _, _, err := chain.NewMultiWithCustomConfigAndStoreNoCheck(t, customConfig, cache)
|
_, _, _, err := chain.NewMultiWithCustomConfigAndStoreNoCheck(t, customConfig, cache)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.True(t, strings.Contains(err.Error(), "can't init cache for Management native contract"), err)
|
require.True(t, strings.Contains(err.Error(), "can't init natives cache: failed to initialize cache for ContractManagement"), err)
|
||||||
})
|
})
|
||||||
*/
|
|
||||||
t.Run("invalid native contract activation", func(t *testing.T) {
|
t.Run("invalid native contract activation", func(t *testing.T) {
|
||||||
ps = newPS(t)
|
ps = newPS(t)
|
||||||
|
|
||||||
|
|
|
@ -389,7 +389,7 @@ func (m *Management) Deploy(ic *interop.Context, sender util.Uint160, neff *nef.
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
err = manif.IsValid(h)
|
err = manif.IsValid(h, false) // do not check manifest size, the whole state.Contract will be checked later.
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("invalid manifest: %w", err)
|
return nil, fmt.Errorf("invalid manifest: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -458,7 +458,7 @@ func (m *Management) Update(ic *interop.Context, hash util.Uint160, neff *nef.Fi
|
||||||
if manif.Name != contract.Manifest.Name {
|
if manif.Name != contract.Manifest.Name {
|
||||||
return nil, errors.New("contract name can't be changed")
|
return nil, errors.New("contract name can't be changed")
|
||||||
}
|
}
|
||||||
err = manif.IsValid(contract.Hash)
|
err = manif.IsValid(contract.Hash, false) // do not check manifest size, the whole state.Contract will be checked later.
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("invalid manifest: %w", err)
|
return nil, fmt.Errorf("invalid manifest: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -619,13 +619,19 @@ func (m *Management) InitializeCache(blockHeight uint32, d *dao.Simple) error {
|
||||||
nep17: make(map[util.Uint160]struct{}),
|
nep17: make(map[util.Uint160]struct{}),
|
||||||
}
|
}
|
||||||
|
|
||||||
d.Seek(m.ID, storage.SeekRange{Prefix: []byte{PrefixContract}}, func(k, v []byte) bool {
|
var initErr error
|
||||||
|
d.Seek(m.ID, storage.SeekRange{Prefix: []byte{PrefixContract}}, func(_, v []byte) bool {
|
||||||
var cs = new(state.Contract)
|
var cs = new(state.Contract)
|
||||||
if stackitem.DeserializeConvertible(v, cs) == nil {
|
initErr = stackitem.DeserializeConvertible(v, cs)
|
||||||
updateContractCache(cache, cs)
|
if initErr != nil {
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
updateContractCache(cache, cs)
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
|
if initErr != nil {
|
||||||
|
return initErr
|
||||||
|
}
|
||||||
d.SetCache(m.ID, cache)
|
d.SetCache(m.ID, cache)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -85,14 +85,12 @@ func TestManagement_Initialize(t *testing.T) {
|
||||||
mgmt := newManagement()
|
mgmt := newManagement()
|
||||||
require.NoError(t, mgmt.InitializeCache(0, d))
|
require.NoError(t, mgmt.InitializeCache(0, d))
|
||||||
})
|
})
|
||||||
/* See #2801
|
|
||||||
t.Run("invalid contract state", func(t *testing.T) {
|
t.Run("invalid contract state", func(t *testing.T) {
|
||||||
d := dao.NewSimple(storage.NewMemoryStore(), false)
|
d := dao.NewSimple(storage.NewMemoryStore(), false)
|
||||||
mgmt := newManagement()
|
mgmt := newManagement()
|
||||||
d.PutStorageItem(mgmt.ID, []byte{PrefixContract}, state.StorageItem{0xFF})
|
d.PutStorageItem(mgmt.ID, []byte{PrefixContract}, state.StorageItem{0xFF})
|
||||||
require.Error(t, mgmt.InitializeCache(d))
|
require.Error(t, mgmt.InitializeCache(0, d))
|
||||||
})
|
})
|
||||||
*/
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestManagement_GetNEP17Contracts(t *testing.T) {
|
func TestManagement_GetNEP17Contracts(t *testing.T) {
|
||||||
|
|
|
@ -84,7 +84,7 @@ func (m *Manifest) CanCall(hash util.Uint160, toCall *Manifest, method string) b
|
||||||
// IsValid checks manifest internal consistency and correctness, one of the
|
// IsValid checks manifest internal consistency and correctness, one of the
|
||||||
// checks is for group signature correctness, contract hash is passed for it.
|
// checks is for group signature correctness, contract hash is passed for it.
|
||||||
// If hash is empty, then hash-related checks are omitted.
|
// If hash is empty, then hash-related checks are omitted.
|
||||||
func (m *Manifest) IsValid(hash util.Uint160) error {
|
func (m *Manifest) IsValid(hash util.Uint160, checkSize bool) error {
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
if m.Name == "" {
|
if m.Name == "" {
|
||||||
|
@ -118,7 +118,22 @@ func (m *Manifest) IsValid(hash util.Uint160) error {
|
||||||
return errors.New("duplicate trusted contracts")
|
return errors.New("duplicate trusted contracts")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return Permissions(m.Permissions).AreValid()
|
err = Permissions(m.Permissions).AreValid()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !checkSize {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
si, err := m.ToStackItem()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to check manifest serialisation: %w", err)
|
||||||
|
}
|
||||||
|
_, err = stackitem.Serialize(si)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("manifest is not serializable: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsStandardSupported denotes whether the specified standard is supported by the contract.
|
// IsStandardSupported denotes whether the specified standard is supported by the contract.
|
||||||
|
|
|
@ -2,6 +2,7 @@ package manifest
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"math/big"
|
"math/big"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
@ -124,13 +125,13 @@ func TestIsValid(t *testing.T) {
|
||||||
m := &Manifest{}
|
m := &Manifest{}
|
||||||
|
|
||||||
t.Run("invalid, no name", func(t *testing.T) {
|
t.Run("invalid, no name", func(t *testing.T) {
|
||||||
require.Error(t, m.IsValid(contractHash))
|
require.Error(t, m.IsValid(contractHash, true))
|
||||||
})
|
})
|
||||||
|
|
||||||
m = NewManifest("Test")
|
m = NewManifest("Test")
|
||||||
|
|
||||||
t.Run("invalid, no ABI methods", func(t *testing.T) {
|
t.Run("invalid, no ABI methods", func(t *testing.T) {
|
||||||
require.Error(t, m.IsValid(contractHash))
|
require.Error(t, m.IsValid(contractHash, true))
|
||||||
})
|
})
|
||||||
|
|
||||||
m.ABI.Methods = append(m.ABI.Methods, Method{
|
m.ABI.Methods = append(m.ABI.Methods, Method{
|
||||||
|
@ -140,7 +141,7 @@ func TestIsValid(t *testing.T) {
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("valid, no groups/events", func(t *testing.T) {
|
t.Run("valid, no groups/events", func(t *testing.T) {
|
||||||
require.NoError(t, m.IsValid(contractHash))
|
require.NoError(t, m.IsValid(contractHash, true))
|
||||||
})
|
})
|
||||||
|
|
||||||
m.ABI.Events = append(m.ABI.Events, Event{
|
m.ABI.Events = append(m.ABI.Events, Event{
|
||||||
|
@ -149,7 +150,7 @@ func TestIsValid(t *testing.T) {
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("valid, with events", func(t *testing.T) {
|
t.Run("valid, with events", func(t *testing.T) {
|
||||||
require.NoError(t, m.IsValid(contractHash))
|
require.NoError(t, m.IsValid(contractHash, true))
|
||||||
})
|
})
|
||||||
|
|
||||||
m.ABI.Events = append(m.ABI.Events, Event{
|
m.ABI.Events = append(m.ABI.Events, Event{
|
||||||
|
@ -161,52 +162,52 @@ func TestIsValid(t *testing.T) {
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("invalid, bad event", func(t *testing.T) {
|
t.Run("invalid, bad event", func(t *testing.T) {
|
||||||
require.Error(t, m.IsValid(contractHash))
|
require.Error(t, m.IsValid(contractHash, true))
|
||||||
})
|
})
|
||||||
m.ABI.Events = m.ABI.Events[:1]
|
m.ABI.Events = m.ABI.Events[:1]
|
||||||
|
|
||||||
m.Permissions = append(m.Permissions, *NewPermission(PermissionHash, util.Uint160{1, 2, 3}))
|
m.Permissions = append(m.Permissions, *NewPermission(PermissionHash, util.Uint160{1, 2, 3}))
|
||||||
t.Run("valid, with permissions", func(t *testing.T) {
|
t.Run("valid, with permissions", func(t *testing.T) {
|
||||||
require.NoError(t, m.IsValid(contractHash))
|
require.NoError(t, m.IsValid(contractHash, true))
|
||||||
})
|
})
|
||||||
|
|
||||||
m.Permissions = append(m.Permissions, *NewPermission(PermissionHash, util.Uint160{1, 2, 3}))
|
m.Permissions = append(m.Permissions, *NewPermission(PermissionHash, util.Uint160{1, 2, 3}))
|
||||||
t.Run("invalid, with permissions", func(t *testing.T) {
|
t.Run("invalid, with permissions", func(t *testing.T) {
|
||||||
require.Error(t, m.IsValid(contractHash))
|
require.Error(t, m.IsValid(contractHash, true))
|
||||||
})
|
})
|
||||||
m.Permissions = m.Permissions[:1]
|
m.Permissions = m.Permissions[:1]
|
||||||
|
|
||||||
m.SupportedStandards = append(m.SupportedStandards, "NEP-17")
|
m.SupportedStandards = append(m.SupportedStandards, "NEP-17")
|
||||||
t.Run("valid, with standards", func(t *testing.T) {
|
t.Run("valid, with standards", func(t *testing.T) {
|
||||||
require.NoError(t, m.IsValid(contractHash))
|
require.NoError(t, m.IsValid(contractHash, true))
|
||||||
})
|
})
|
||||||
|
|
||||||
m.SupportedStandards = append(m.SupportedStandards, "")
|
m.SupportedStandards = append(m.SupportedStandards, "")
|
||||||
t.Run("invalid, with nameless standard", func(t *testing.T) {
|
t.Run("invalid, with nameless standard", func(t *testing.T) {
|
||||||
require.Error(t, m.IsValid(contractHash))
|
require.Error(t, m.IsValid(contractHash, true))
|
||||||
})
|
})
|
||||||
m.SupportedStandards = m.SupportedStandards[:1]
|
m.SupportedStandards = m.SupportedStandards[:1]
|
||||||
|
|
||||||
m.SupportedStandards = append(m.SupportedStandards, "NEP-17")
|
m.SupportedStandards = append(m.SupportedStandards, "NEP-17")
|
||||||
t.Run("invalid, with duplicate standards", func(t *testing.T) {
|
t.Run("invalid, with duplicate standards", func(t *testing.T) {
|
||||||
require.Error(t, m.IsValid(contractHash))
|
require.Error(t, m.IsValid(contractHash, true))
|
||||||
})
|
})
|
||||||
m.SupportedStandards = m.SupportedStandards[:1]
|
m.SupportedStandards = m.SupportedStandards[:1]
|
||||||
|
|
||||||
d := PermissionDesc{Type: PermissionHash, Value: util.Uint160{1, 2, 3}}
|
d := PermissionDesc{Type: PermissionHash, Value: util.Uint160{1, 2, 3}}
|
||||||
m.Trusts.Add(d)
|
m.Trusts.Add(d)
|
||||||
t.Run("valid, with trust", func(t *testing.T) {
|
t.Run("valid, with trust", func(t *testing.T) {
|
||||||
require.NoError(t, m.IsValid(contractHash))
|
require.NoError(t, m.IsValid(contractHash, true))
|
||||||
})
|
})
|
||||||
|
|
||||||
m.Trusts.Add(PermissionDesc{Type: PermissionHash, Value: util.Uint160{3, 2, 1}})
|
m.Trusts.Add(PermissionDesc{Type: PermissionHash, Value: util.Uint160{3, 2, 1}})
|
||||||
t.Run("valid, with trusts", func(t *testing.T) {
|
t.Run("valid, with trusts", func(t *testing.T) {
|
||||||
require.NoError(t, m.IsValid(contractHash))
|
require.NoError(t, m.IsValid(contractHash, true))
|
||||||
})
|
})
|
||||||
|
|
||||||
m.Trusts.Add(d)
|
m.Trusts.Add(d)
|
||||||
t.Run("invalid, with trusts", func(t *testing.T) {
|
t.Run("invalid, with trusts", func(t *testing.T) {
|
||||||
require.Error(t, m.IsValid(contractHash))
|
require.Error(t, m.IsValid(contractHash, true))
|
||||||
})
|
})
|
||||||
m.Trusts.Restrict()
|
m.Trusts.Restrict()
|
||||||
|
|
||||||
|
@ -224,11 +225,11 @@ func TestIsValid(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Run("valid", func(t *testing.T) {
|
t.Run("valid", func(t *testing.T) {
|
||||||
require.NoError(t, m.IsValid(contractHash))
|
require.NoError(t, m.IsValid(contractHash, true))
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("invalid, wrong contract hash", func(t *testing.T) {
|
t.Run("invalid, wrong contract hash", func(t *testing.T) {
|
||||||
require.Error(t, m.IsValid(util.Uint160{4, 5, 6}))
|
require.Error(t, m.IsValid(util.Uint160{4, 5, 6}, true))
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("invalid, wrong group signature", func(t *testing.T) {
|
t.Run("invalid, wrong group signature", func(t *testing.T) {
|
||||||
|
@ -240,9 +241,21 @@ func TestIsValid(t *testing.T) {
|
||||||
// of the contract hash.
|
// of the contract hash.
|
||||||
Signature: pk.Sign([]byte{1, 2, 3}),
|
Signature: pk.Sign([]byte{1, 2, 3}),
|
||||||
})
|
})
|
||||||
require.Error(t, m.IsValid(contractHash))
|
require.Error(t, m.IsValid(contractHash, true))
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
m.Groups = m.Groups[:0]
|
||||||
|
|
||||||
|
t.Run("invalid, unserializable", func(t *testing.T) {
|
||||||
|
for i := 0; i < stackitem.MaxSerialized; i++ {
|
||||||
|
m.ABI.Events = append(m.ABI.Events, Event{
|
||||||
|
Name: fmt.Sprintf("Event%d", i),
|
||||||
|
Parameters: []Parameter{},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
err := m.IsValid(contractHash, true)
|
||||||
|
require.ErrorIs(t, err, stackitem.ErrTooBig)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestManifestToStackItem(t *testing.T) {
|
func TestManifestToStackItem(t *testing.T) {
|
||||||
|
|
|
@ -65,9 +65,12 @@ func ToJSON(item Item) ([]byte, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// sliceNoPointer represents a sub-slice of a known slice.
|
// sliceNoPointer represents a sub-slice of a known slice.
|
||||||
// It doesn't contain any pointer and uses less memory than `[]byte`.
|
// It doesn't contain any pointer and uses the same amount of memory as `[]byte`,
|
||||||
|
// but at the same type has additional information about the number of items in
|
||||||
|
// the stackitem (including the stackitem itself).
|
||||||
type sliceNoPointer struct {
|
type sliceNoPointer struct {
|
||||||
start, end int
|
start, end int
|
||||||
|
itemsCount int
|
||||||
}
|
}
|
||||||
|
|
||||||
func toJSON(data []byte, seen map[Item]sliceNoPointer, item Item) ([]byte, error) {
|
func toJSON(data []byte, seen map[Item]sliceNoPointer, item Item) ([]byte, error) {
|
||||||
|
@ -105,7 +108,7 @@ func toJSON(data []byte, seen map[Item]sliceNoPointer, item Item) ([]byte, error
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
data = append(data, ']')
|
data = append(data, ']')
|
||||||
seen[item] = sliceNoPointer{start, len(data)}
|
seen[item] = sliceNoPointer{start: start, end: len(data)}
|
||||||
case *Map:
|
case *Map:
|
||||||
data = append(data, '{')
|
data = append(data, '{')
|
||||||
for i := range it.value {
|
for i := range it.value {
|
||||||
|
@ -126,7 +129,7 @@ func toJSON(data []byte, seen map[Item]sliceNoPointer, item Item) ([]byte, error
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
data = append(data, '}')
|
data = append(data, '}')
|
||||||
seen[item] = sliceNoPointer{start, len(data)}
|
seen[item] = sliceNoPointer{start: start, end: len(data)}
|
||||||
case *BigInteger:
|
case *BigInteger:
|
||||||
if it.Big().CmpAbs(big.NewInt(MaxAllowedInteger)) == 1 {
|
if it.Big().CmpAbs(big.NewInt(MaxAllowedInteger)) == 1 {
|
||||||
return nil, fmt.Errorf("%w (MaxAllowedInteger)", ErrInvalidValue)
|
return nil, fmt.Errorf("%w (MaxAllowedInteger)", ErrInvalidValue)
|
||||||
|
@ -420,7 +423,7 @@ func toJSONWithTypes(data []byte, item Item, seen map[Item]sliceNoPointer) ([]by
|
||||||
data = append(data, '}')
|
data = append(data, '}')
|
||||||
|
|
||||||
if isBuffer {
|
if isBuffer {
|
||||||
seen[item] = sliceNoPointer{start, len(data)}
|
seen[item] = sliceNoPointer{start: start, end: len(data)}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if len(data)+2 > MaxSize { // also take care of '}'
|
if len(data)+2 > MaxSize { // also take care of '}'
|
||||||
|
@ -428,7 +431,7 @@ func toJSONWithTypes(data []byte, item Item, seen map[Item]sliceNoPointer) ([]by
|
||||||
}
|
}
|
||||||
data = append(data, ']', '}')
|
data = append(data, ']', '}')
|
||||||
|
|
||||||
seen[item] = sliceNoPointer{start, len(data)}
|
seen[item] = sliceNoPointer{start: start, end: len(data)}
|
||||||
}
|
}
|
||||||
return data, nil
|
return data, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -14,6 +14,10 @@ import (
|
||||||
// (including itself).
|
// (including itself).
|
||||||
const MaxDeserialized = 2048
|
const MaxDeserialized = 2048
|
||||||
|
|
||||||
|
// MaxSerialized is the maximum number one serialized item can contain
|
||||||
|
// (including itself).
|
||||||
|
const MaxSerialized = MaxDeserialized
|
||||||
|
|
||||||
// typicalNumOfItems is the number of items covering most serialization needs.
|
// typicalNumOfItems is the number of items covering most serialization needs.
|
||||||
// It's a hint used for map creation, so it does not limit anything, it's just
|
// It's a hint used for map creation, so it does not limit anything, it's just
|
||||||
// a microoptimization to avoid excessive reallocations. Most of the serialized
|
// a microoptimization to avoid excessive reallocations. Most of the serialized
|
||||||
|
@ -33,6 +37,7 @@ type SerializationContext struct {
|
||||||
uv [9]byte
|
uv [9]byte
|
||||||
data []byte
|
data []byte
|
||||||
allowInvalid bool
|
allowInvalid bool
|
||||||
|
limit int
|
||||||
seen map[Item]sliceNoPointer
|
seen map[Item]sliceNoPointer
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -45,10 +50,20 @@ type deserContext struct {
|
||||||
|
|
||||||
// Serialize encodes the given Item into a byte slice.
|
// Serialize encodes the given Item into a byte slice.
|
||||||
func Serialize(item Item) ([]byte, error) {
|
func Serialize(item Item) ([]byte, error) {
|
||||||
|
return SerializeLimited(item, MaxSerialized)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SerializeLimited encodes the given Item into a byte slice using custom
|
||||||
|
// limit to restrict the maximum serialized number of elements.
|
||||||
|
func SerializeLimited(item Item, limit int) ([]byte, error) {
|
||||||
sc := SerializationContext{
|
sc := SerializationContext{
|
||||||
allowInvalid: false,
|
allowInvalid: false,
|
||||||
|
limit: MaxSerialized,
|
||||||
seen: make(map[Item]sliceNoPointer, typicalNumOfItems),
|
seen: make(map[Item]sliceNoPointer, typicalNumOfItems),
|
||||||
}
|
}
|
||||||
|
if limit > 0 {
|
||||||
|
sc.limit = limit
|
||||||
|
}
|
||||||
err := sc.serialize(item)
|
err := sc.serialize(item)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -76,6 +91,7 @@ func EncodeBinary(item Item, w *io.BinWriter) {
|
||||||
func EncodeBinaryProtected(item Item, w *io.BinWriter) {
|
func EncodeBinaryProtected(item Item, w *io.BinWriter) {
|
||||||
sc := SerializationContext{
|
sc := SerializationContext{
|
||||||
allowInvalid: true,
|
allowInvalid: true,
|
||||||
|
limit: MaxSerialized,
|
||||||
seen: make(map[Item]sliceNoPointer, typicalNumOfItems),
|
seen: make(map[Item]sliceNoPointer, typicalNumOfItems),
|
||||||
}
|
}
|
||||||
err := sc.serialize(item)
|
err := sc.serialize(item)
|
||||||
|
@ -88,19 +104,21 @@ func EncodeBinaryProtected(item Item, w *io.BinWriter) {
|
||||||
|
|
||||||
func (w *SerializationContext) writeArray(item Item, arr []Item, start int) error {
|
func (w *SerializationContext) writeArray(item Item, arr []Item, start int) error {
|
||||||
w.seen[item] = sliceNoPointer{}
|
w.seen[item] = sliceNoPointer{}
|
||||||
|
limit := w.limit
|
||||||
w.appendVarUint(uint64(len(arr)))
|
w.appendVarUint(uint64(len(arr)))
|
||||||
for i := range arr {
|
for i := range arr {
|
||||||
if err := w.serialize(arr[i]); err != nil {
|
if err := w.serialize(arr[i]); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
w.seen[item] = sliceNoPointer{start, len(w.data)}
|
w.seen[item] = sliceNoPointer{start, len(w.data), limit - w.limit + 1} // number of items including the array itself.
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewSerializationContext returns reusable stack item serialization context.
|
// NewSerializationContext returns reusable stack item serialization context.
|
||||||
func NewSerializationContext() *SerializationContext {
|
func NewSerializationContext() *SerializationContext {
|
||||||
return &SerializationContext{
|
return &SerializationContext{
|
||||||
|
limit: MaxSerialized,
|
||||||
seen: make(map[Item]sliceNoPointer, typicalNumOfItems),
|
seen: make(map[Item]sliceNoPointer, typicalNumOfItems),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -108,8 +126,10 @@ func NewSerializationContext() *SerializationContext {
|
||||||
// Serialize returns flat slice of bytes with the given item. The process can be protected
|
// Serialize returns flat slice of bytes with the given item. The process can be protected
|
||||||
// from bad elements if appropriate flag is given (otherwise an error is returned on
|
// from bad elements if appropriate flag is given (otherwise an error is returned on
|
||||||
// encountering any of them). The buffer returned is only valid until the call to Serialize.
|
// encountering any of them). The buffer returned is only valid until the call to Serialize.
|
||||||
|
// The number of serialized items is restricted with MaxSerialized.
|
||||||
func (w *SerializationContext) Serialize(item Item, protected bool) ([]byte, error) {
|
func (w *SerializationContext) Serialize(item Item, protected bool) ([]byte, error) {
|
||||||
w.allowInvalid = protected
|
w.allowInvalid = protected
|
||||||
|
w.limit = MaxSerialized
|
||||||
if w.data != nil {
|
if w.data != nil {
|
||||||
w.data = w.data[:0]
|
w.data = w.data[:0]
|
||||||
}
|
}
|
||||||
|
@ -135,10 +155,17 @@ func (w *SerializationContext) serialize(item Item) error {
|
||||||
if len(w.data)+v.end-v.start > MaxSize {
|
if len(w.data)+v.end-v.start > MaxSize {
|
||||||
return ErrTooBig
|
return ErrTooBig
|
||||||
}
|
}
|
||||||
|
w.limit -= v.itemsCount
|
||||||
|
if w.limit < 0 {
|
||||||
|
return errTooBigElements
|
||||||
|
}
|
||||||
w.data = append(w.data, w.data[v.start:v.end]...)
|
w.data = append(w.data, w.data[v.start:v.end]...)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
w.limit--
|
||||||
|
if w.limit < 0 {
|
||||||
|
return errTooBigElements
|
||||||
|
}
|
||||||
start := len(w.data)
|
start := len(w.data)
|
||||||
switch t := item.(type) {
|
switch t := item.(type) {
|
||||||
case *ByteArray:
|
case *ByteArray:
|
||||||
|
@ -188,6 +215,7 @@ func (w *SerializationContext) serialize(item Item) error {
|
||||||
}
|
}
|
||||||
case *Map:
|
case *Map:
|
||||||
w.seen[item] = sliceNoPointer{}
|
w.seen[item] = sliceNoPointer{}
|
||||||
|
limit := w.limit
|
||||||
|
|
||||||
elems := t.value
|
elems := t.value
|
||||||
w.data = append(w.data, byte(MapT))
|
w.data = append(w.data, byte(MapT))
|
||||||
|
@ -200,7 +228,7 @@ func (w *SerializationContext) serialize(item Item) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
w.seen[item] = sliceNoPointer{start, len(w.data)}
|
w.seen[item] = sliceNoPointer{start, len(w.data), limit - w.limit + 1} // number of items including Map itself.
|
||||||
case Null:
|
case Null:
|
||||||
w.data = append(w.data, byte(AnyT))
|
w.data = append(w.data, byte(AnyT))
|
||||||
case nil:
|
case nil:
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package stackitem
|
package stackitem
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"strconv"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/nspcc-dev/neo-go/pkg/io"
|
"github.com/nspcc-dev/neo-go/pkg/io"
|
||||||
|
@ -23,7 +24,19 @@ func TestSerializationMaxErr(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func testSerialize(t *testing.T, expectedErr error, item Item) {
|
func testSerialize(t *testing.T, expectedErr error, item Item) {
|
||||||
data, err := Serialize(item)
|
testSerializeLimited(t, expectedErr, item, -1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testSerializeLimited(t *testing.T, expectedErr error, item Item, limit int) {
|
||||||
|
var (
|
||||||
|
data []byte
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
if limit > 0 {
|
||||||
|
data, err = SerializeLimited(item, limit)
|
||||||
|
} else {
|
||||||
|
data, err = Serialize(item)
|
||||||
|
}
|
||||||
if expectedErr != nil {
|
if expectedErr != nil {
|
||||||
require.ErrorIs(t, err, expectedErr)
|
require.ErrorIs(t, err, expectedErr)
|
||||||
return
|
return
|
||||||
|
@ -58,7 +71,9 @@ func TestSerialize(t *testing.T) {
|
||||||
testSerialize(t, nil, newItem(items))
|
testSerialize(t, nil, newItem(items))
|
||||||
|
|
||||||
items = append(items, zeroByteArray)
|
items = append(items, zeroByteArray)
|
||||||
data, err := Serialize(newItem(items))
|
_, err := Serialize(newItem(items))
|
||||||
|
require.ErrorIs(t, err, errTooBigElements)
|
||||||
|
data, err := SerializeLimited(newItem(items), MaxSerialized+1) // a tiny hack to check deserialization error.
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
_, err = Deserialize(data)
|
_, err = Deserialize(data)
|
||||||
require.ErrorIs(t, err, ErrTooBig)
|
require.ErrorIs(t, err, ErrTooBig)
|
||||||
|
@ -165,13 +180,70 @@ func TestSerialize(t *testing.T) {
|
||||||
for i := 0; i <= MaxDeserialized; i++ {
|
for i := 0; i <= MaxDeserialized; i++ {
|
||||||
m.Add(Make(i), zeroByteArray)
|
m.Add(Make(i), zeroByteArray)
|
||||||
}
|
}
|
||||||
data, err := Serialize(m)
|
_, err := Serialize(m)
|
||||||
|
require.ErrorIs(t, err, errTooBigElements)
|
||||||
|
data, err := SerializeLimited(m, (MaxSerialized+1)*2+1) // a tiny hack to check deserialization error.
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
_, err = Deserialize(data)
|
_, err = Deserialize(data)
|
||||||
require.ErrorIs(t, err, ErrTooBig)
|
require.ErrorIs(t, err, ErrTooBig)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSerializeLimited(t *testing.T) {
|
||||||
|
const customLimit = 10
|
||||||
|
|
||||||
|
smallArray := make([]Item, customLimit-1)
|
||||||
|
for i := range smallArray {
|
||||||
|
smallArray[i] = NewBool(true)
|
||||||
|
}
|
||||||
|
bigArray := make([]Item, customLimit)
|
||||||
|
for i := range bigArray {
|
||||||
|
bigArray[i] = NewBool(true)
|
||||||
|
}
|
||||||
|
t.Run("array", func(t *testing.T) {
|
||||||
|
testSerializeLimited(t, nil, NewArray(smallArray), customLimit)
|
||||||
|
testSerializeLimited(t, errTooBigElements, NewArray(bigArray), customLimit)
|
||||||
|
})
|
||||||
|
t.Run("struct", func(t *testing.T) {
|
||||||
|
testSerializeLimited(t, nil, NewStruct(smallArray), customLimit)
|
||||||
|
testSerializeLimited(t, errTooBigElements, NewStruct(bigArray), customLimit)
|
||||||
|
})
|
||||||
|
t.Run("map", func(t *testing.T) {
|
||||||
|
smallMap := make([]MapElement, (customLimit-1)/2)
|
||||||
|
for i := range smallMap {
|
||||||
|
smallMap[i] = MapElement{
|
||||||
|
Key: NewByteArray([]byte(strconv.Itoa(i))),
|
||||||
|
Value: NewBool(true),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
bigMap := make([]MapElement, customLimit/2)
|
||||||
|
for i := range bigMap {
|
||||||
|
bigMap[i] = MapElement{
|
||||||
|
Key: NewByteArray([]byte("key")),
|
||||||
|
Value: NewBool(true),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
testSerializeLimited(t, nil, NewMapWithValue(smallMap), customLimit)
|
||||||
|
testSerializeLimited(t, errTooBigElements, NewMapWithValue(bigMap), customLimit)
|
||||||
|
})
|
||||||
|
t.Run("seen", func(t *testing.T) {
|
||||||
|
t.Run("OK", func(t *testing.T) {
|
||||||
|
tinyArray := NewArray(make([]Item, (customLimit-3)/2)) // 1 for outer array, 1+1 for two inner arrays and the rest are for arrays' elements.
|
||||||
|
for i := range tinyArray.value {
|
||||||
|
tinyArray.value[i] = NewBool(true)
|
||||||
|
}
|
||||||
|
testSerializeLimited(t, nil, NewArray([]Item{tinyArray, tinyArray}), customLimit)
|
||||||
|
})
|
||||||
|
t.Run("big", func(t *testing.T) {
|
||||||
|
tinyArray := NewArray(make([]Item, (customLimit-2)/2)) // should break on the second array serialisation.
|
||||||
|
for i := range tinyArray.value {
|
||||||
|
tinyArray.value[i] = NewBool(true)
|
||||||
|
}
|
||||||
|
testSerializeLimited(t, errTooBigElements, NewArray([]Item{tinyArray, tinyArray}), customLimit)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestEmptyDeserialization(t *testing.T) {
|
func TestEmptyDeserialization(t *testing.T) {
|
||||||
empty := []byte{}
|
empty := []byte{}
|
||||||
_, err := Deserialize(empty)
|
_, err := Deserialize(empty)
|
||||||
|
@ -202,7 +274,7 @@ func TestDeserializeTooManyElements(t *testing.T) {
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
item = Make([]Item{item})
|
item = Make([]Item{item})
|
||||||
data, err = Serialize(item)
|
data, err = SerializeLimited(item, MaxSerialized+1) // tiny hack to avoid serialization error.
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
_, err = Deserialize(data)
|
_, err = Deserialize(data)
|
||||||
require.ErrorIs(t, err, ErrTooBig)
|
require.ErrorIs(t, err, ErrTooBig)
|
||||||
|
@ -214,14 +286,14 @@ func TestDeserializeLimited(t *testing.T) {
|
||||||
for i := 0; i < customLimit-1; i++ { // 1 for zero inner element.
|
for i := 0; i < customLimit-1; i++ { // 1 for zero inner element.
|
||||||
item = Make([]Item{item})
|
item = Make([]Item{item})
|
||||||
}
|
}
|
||||||
data, err := Serialize(item)
|
data, err := SerializeLimited(item, customLimit) // tiny hack to avoid serialization error.
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
actual, err := DeserializeLimited(data, customLimit)
|
actual, err := DeserializeLimited(data, customLimit)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, item, actual)
|
require.Equal(t, item, actual)
|
||||||
|
|
||||||
item = Make([]Item{item})
|
item = Make([]Item{item})
|
||||||
data, err = Serialize(item)
|
data, err = SerializeLimited(item, customLimit+1) // tiny hack to avoid serialization error.
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
_, err = DeserializeLimited(data, customLimit)
|
_, err = DeserializeLimited(data, customLimit)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
|
|
Loading…
Reference in a new issue