diff --git a/internal/testserdes/testing.go b/internal/testserdes/testing.go index 0be71a844..d30e1ddaa 100644 --- a/internal/testserdes/testing.go +++ b/internal/testserdes/testing.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/nspcc-dev/neo-go/pkg/io" + "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" "github.com/stretchr/testify/require" ) @@ -26,6 +27,15 @@ func EncodeDecodeBinary(t *testing.T, expected, actual io.Serializable) { require.Equal(t, expected, actual) } +// ToFromStackItem checks if expected stays the same after converting to/from +// StackItem. +func ToFromStackItem(t *testing.T, expected, actual stackitem.Convertible) { + item, err := expected.ToStackItem() + require.NoError(t, err) + require.NoError(t, actual.FromStackItem(item)) + require.Equal(t, expected, actual) +} + // EncodeBinary serializes a to a byte slice. func EncodeBinary(a io.Serializable) ([]byte, error) { w := io.NewBufBinWriter() diff --git a/pkg/compiler/assign_test.go b/pkg/compiler/assign_test.go index 4c7007a0f..182224ea4 100644 --- a/pkg/compiler/assign_test.go +++ b/pkg/compiler/assign_test.go @@ -3,8 +3,6 @@ package compiler_test import ( "math/big" "testing" - - "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" ) var assignTestCases = []testCase{ @@ -154,9 +152,9 @@ func TestManyAssignments(t *testing.T) { src2 := `return a }` - for i := 0; i < stackitem.MaxArraySize; i++ { + for i := 0; i < 1024; i++ { src1 += "a += 1\n" } - eval(t, src1+src2, big.NewInt(stackitem.MaxArraySize)) + eval(t, src1+src2, big.NewInt(1024)) } diff --git a/pkg/core/blockchain.go b/pkg/core/blockchain.go index c62a3e202..5a820a967 100644 --- a/pkg/core/blockchain.go +++ b/pkg/core/blockchain.go @@ -361,24 +361,18 @@ func (bc *Blockchain) init() error { if storedCS == nil { return fmt.Errorf("native contract %s is not stored", md.Name) } - w := io.NewBufBinWriter() - storedCS.EncodeBinary(w.BinWriter) - if w.Err != nil { - return fmt.Errorf("failed to check native %s state against autogenerated one: %w", md.Name, w.Err) + storedCSBytes, err := stackitem.SerializeConvertible(storedCS) + if err != nil { + return fmt.Errorf("failed to check native %s state against autogenerated one: %w", md.Name, err) } - buff := w.Bytes() - storedCSBytes := make([]byte, len(buff)) - copy(storedCSBytes, buff) - w.Reset() autogenCS := &state.Contract{ ContractBase: md.ContractBase, UpdateCounter: storedCS.UpdateCounter, // it can be restored only from the DB, so use the stored value. } - autogenCS.EncodeBinary(w.BinWriter) - if w.Err != nil { - return fmt.Errorf("failed to check native %s state against autogenerated one: %w", md.Name, w.Err) + autogenCSBytes, err := stackitem.SerializeConvertible(autogenCS) + if err != nil { + return fmt.Errorf("failed to check native %s state against autogenerated one: %w", md.Name, err) } - autogenCSBytes := w.Bytes() if !bytes.Equal(storedCSBytes, autogenCSBytes) { return fmt.Errorf("native %s: version mismatch (stored contract state differs from autogenerated one), "+ "try to resynchronize the node from the genesis", md.Name) diff --git a/pkg/core/fee/opcode.go b/pkg/core/fee/opcode.go index b9c788dbc..806ee1a8d 100644 --- a/pkg/core/fee/opcode.go +++ b/pkg/core/fee/opcode.go @@ -164,7 +164,7 @@ var coefficients = map[opcode.Opcode]int64{ opcode.DIV: 1 << 3, opcode.MOD: 1 << 3, opcode.POW: 1 << 6, - opcode.SQRT: 1 << 11, + opcode.SQRT: 1 << 6, opcode.SHL: 1 << 3, opcode.SHR: 1 << 3, opcode.NOT: 1 << 2, diff --git a/pkg/core/interop_system.go b/pkg/core/interop_system.go index e7578180f..e83e3880f 100644 --- a/pkg/core/interop_system.go +++ b/pkg/core/interop_system.go @@ -219,8 +219,9 @@ func storageFind(ic *interop.Context) error { // given m and a set of public keys. func contractCreateMultisigAccount(ic *interop.Context) error { m := ic.VM.Estack().Pop().BigInt() - if !m.IsInt64() || m.Int64() > math.MaxInt32 { - return errors.New("m should fit int32") + mu64 := m.Uint64() + if !m.IsUint64() || mu64 > math.MaxInt32 { + return errors.New("m must be positive and fit int32") } arr := ic.VM.Estack().Pop().Array() pubs := make(keys.PublicKeys, len(arr)) @@ -231,7 +232,7 @@ func contractCreateMultisigAccount(ic *interop.Context) error { } pubs[i] = p } - script, err := smartcontract.CreateMultiSigRedeemScript(int(m.Int64()), pubs) + script, err := smartcontract.CreateMultiSigRedeemScript(int(mu64), pubs) if err != nil { return err } diff --git a/pkg/core/native/designate.go b/pkg/core/native/designate.go index 05f387dec..8c75eebc8 100644 --- a/pkg/core/native/designate.go +++ b/pkg/core/native/designate.go @@ -18,7 +18,6 @@ import ( "github.com/nspcc-dev/neo-go/pkg/core/state" "github.com/nspcc-dev/neo-go/pkg/crypto/hash" "github.com/nspcc-dev/neo-go/pkg/crypto/keys" - "github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/smartcontract" "github.com/nspcc-dev/neo-go/pkg/smartcontract/callflag" "github.com/nspcc-dev/neo-go/pkg/smartcontract/manifest" @@ -272,10 +271,9 @@ func (s *Designate) GetDesignatedByRole(d dao.DAO, r noderoles.Role, index uint3 } } if resSi != nil { - reader := io.NewBinReaderFromBuf(resSi) - ns.DecodeBinary(reader) - if reader.Err != nil { - return nil, 0, reader.Err + err = stackitem.DeserializeConvertible(resSi, &ns) + if err != nil { + return nil, 0, err } } return keys.PublicKeys(ns), bestIndex, err @@ -287,7 +285,7 @@ func (s *Designate) designateAsRole(ic *interop.Context, args []stackitem.Item) panic(ErrInvalidRole) } var ns NodeList - if err := ns.fromStackItem(args[1]); err != nil { + if err := ns.FromStackItem(args[1]); err != nil { panic(err) } @@ -326,8 +324,9 @@ func (s *Designate) DesignateAsRole(ic *interop.Context, r noderoles.Role, pubs return ErrAlreadyDesignated } sort.Sort(pubs) + nl := NodeList(pubs) s.rolesChangedFlag.Store(true) - err := ic.DAO.PutStorageItem(s.ID, key, NodeList(pubs).Bytes()) + err := putConvertibleToDAO(s.ID, ic.DAO, key, &nl) if err != nil { return err } diff --git a/pkg/core/native/ledger.go b/pkg/core/native/ledger.go index 969bf26dc..8c901162c 100644 --- a/pkg/core/native/ledger.go +++ b/pkg/core/native/ledger.go @@ -157,9 +157,9 @@ func isTraceableBlock(bc blockchainer.Blockchainer, index uint32) bool { // be called within VM context, so it panics if anything goes wrong. func getBlockHashFromItem(bc blockchainer.Blockchainer, item stackitem.Item) util.Uint256 { bigindex, err := item.TryInteger() - if err == nil && bigindex.IsInt64() { - index := bigindex.Int64() - if index < 0 || index > math.MaxUint32 { + if err == nil && bigindex.IsUint64() { + index := bigindex.Uint64() + if index > math.MaxUint32 { panic("bad block index") } if uint32(index) > bc.BlockHeight() { diff --git a/pkg/core/native/management.go b/pkg/core/native/management.go index 032bf3571..a5eb71c52 100644 --- a/pkg/core/native/management.go +++ b/pkg/core/native/management.go @@ -16,7 +16,6 @@ import ( "github.com/nspcc-dev/neo-go/pkg/core/state" "github.com/nspcc-dev/neo-go/pkg/core/storage" "github.com/nspcc-dev/neo-go/pkg/encoding/bigint" - "github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/smartcontract" "github.com/nspcc-dev/neo-go/pkg/smartcontract/callflag" "github.com/nspcc-dev/neo-go/pkg/smartcontract/manifest" @@ -155,7 +154,7 @@ func (m *Management) GetContract(d dao.DAO, hash util.Uint160) (*state.Contract, func (m *Management) getContractFromDAO(d dao.DAO, hash util.Uint160) (*state.Contract, error) { contract := new(state.Contract) key := makeContractKey(hash) - err := getSerializableFromDAO(m.ID, d, key, contract) + err := getConvertibleFromDAO(m.ID, d, key, contract) if err != nil { return nil, err } @@ -487,14 +486,12 @@ func (m *Management) InitializeCache(d dao.DAO) error { var initErr error d.Seek(m.ID, []byte{prefixContract}, func(_, v []byte) { - var cs state.Contract - r := io.NewBinReaderFromBuf(v) - cs.DecodeBinary(r) - if r.Err != nil { - initErr = r.Err + var cs = new(state.Contract) + initErr = stackitem.DeserializeConvertible(v, cs) + if initErr != nil { return } - m.contracts[cs.Hash] = &cs + m.contracts[cs.Hash] = cs }) return initErr } @@ -529,7 +526,7 @@ func (m *Management) Initialize(ic *interop.Context) error { // PutContractState saves given contract state into given DAO. func (m *Management) PutContractState(d dao.DAO, cs *state.Contract) error { key := makeContractKey(cs.Hash) - if err := putSerializableToDAO(m.ID, d, key, cs); err != nil { + if err := putConvertibleToDAO(m.ID, d, key, cs); err != nil { return err } m.markUpdated(cs.Hash) diff --git a/pkg/core/native/native_neo.go b/pkg/core/native/native_neo.go index dadeaee35..3726d9063 100644 --- a/pkg/core/native/native_neo.go +++ b/pkg/core/native/native_neo.go @@ -20,7 +20,6 @@ import ( "github.com/nspcc-dev/neo-go/pkg/crypto/hash" "github.com/nspcc-dev/neo-go/pkg/crypto/keys" "github.com/nspcc-dev/neo-go/pkg/encoding/bigint" - "github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/smartcontract" "github.com/nspcc-dev/neo-go/pkg/smartcontract/callflag" "github.com/nspcc-dev/neo-go/pkg/smartcontract/manifest" @@ -698,7 +697,7 @@ func (n *NEO) RegisterCandidateInternal(ic *interop.Context, pub *keys.PublicKey c = new(candidate).FromBytes(si) c.Registered = true } - return ic.DAO.PutStorageItem(n.ID, key, c.Bytes()) + return putConvertibleToDAO(n.ID, ic.DAO, key, c) } func (n *NEO) unregisterCandidate(ic *interop.Context, args []stackitem.Item) stackitem.Item { @@ -727,7 +726,7 @@ func (n *NEO) UnregisterCandidateInternal(ic *interop.Context, pub *keys.PublicK if ok { return err } - return ic.DAO.PutStorageItem(n.ID, key, c.Bytes()) + return putConvertibleToDAO(n.ID, ic.DAO, key, c) } func (n *NEO) vote(ic *interop.Context, args []stackitem.Item) stackitem.Item { @@ -820,7 +819,7 @@ func (n *NEO) ModifyAccountVotes(acc *state.NEOBalanceState, d dao.DAO, value *b } } n.validators.Store(keys.PublicKeys(nil)) - return d.PutStorageItem(n.ID, key, cd.Bytes()) + return putConvertibleToDAO(n.ID, d, key, cd) } return nil } @@ -903,10 +902,9 @@ func (n *NEO) getAccountState(ic *interop.Context, args []stackitem.Item) stacki return stackitem.Null{} } - r := io.NewBinReaderFromBuf(si) - item := stackitem.DecodeBinary(r) - if r.Err != nil { - panic(r.Err) // no errors are expected but we better be sure + item, err := stackitem.Deserialize(si) + if err != nil { + panic(err) // no errors are expected but we better be sure } return item } diff --git a/pkg/core/native/native_neo_candidate.go b/pkg/core/native/native_neo_candidate.go index a1a7bb44f..8a1d6ca84 100644 --- a/pkg/core/native/native_neo_candidate.go +++ b/pkg/core/native/native_neo_candidate.go @@ -3,7 +3,6 @@ package native import ( "math/big" - "github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" ) @@ -12,40 +11,35 @@ type candidate struct { Votes big.Int } -// Bytes marshals c to byte array. -func (c *candidate) Bytes() []byte { - w := io.NewBufBinWriter() - stackitem.EncodeBinary(c.toStackItem(), w.BinWriter) - return w.Bytes() -} - // FromBytes unmarshals candidate from byte array. func (c *candidate) FromBytes(data []byte) *candidate { - r := io.NewBinReaderFromBuf(data) - item := stackitem.DecodeBinary(r) - if r.Err != nil { - panic(r.Err) + err := stackitem.DeserializeConvertible(data, c) + if err != nil { + panic(err) } - return c.fromStackItem(item) + return c } -func (c *candidate) toStackItem() stackitem.Item { +// ToStackItem implements stackitem.Convertible. It never returns an error. +func (c *candidate) ToStackItem() (stackitem.Item, error) { return stackitem.NewStruct([]stackitem.Item{ stackitem.NewBool(c.Registered), stackitem.NewBigInteger(&c.Votes), - }) + }), nil } -func (c *candidate) fromStackItem(item stackitem.Item) *candidate { +// FromStackItem implements stackitem.Convertible. +func (c *candidate) FromStackItem(item stackitem.Item) error { arr := item.(*stackitem.Struct).Value().([]stackitem.Item) vs, err := arr[1].TryInteger() if err != nil { - panic(err) + return err } - c.Registered, err = arr[0].TryBool() + reg, err := arr[0].TryBool() if err != nil { - panic(err) + return err } + c.Registered = reg c.Votes = *vs - return c + return nil } diff --git a/pkg/core/native/native_neo_test.go b/pkg/core/native/native_neo_test.go index bccd3bd42..ebeb30742 100644 --- a/pkg/core/native/native_neo_test.go +++ b/pkg/core/native/native_neo_test.go @@ -4,7 +4,7 @@ import ( "math/big" "testing" - "github.com/stretchr/testify/require" + "github.com/nspcc-dev/neo-go/internal/testserdes" ) func TestCandidate_Bytes(t *testing.T) { @@ -12,7 +12,6 @@ func TestCandidate_Bytes(t *testing.T) { Registered: true, Votes: *big.NewInt(0x0F), } - data := expected.Bytes() - actual := new(candidate).FromBytes(data) - require.Equal(t, expected, actual) + actual := new(candidate) + testserdes.ToFromStackItem(t, expected, actual) } diff --git a/pkg/core/native/native_nep17.go b/pkg/core/native/native_nep17.go index 0bf47a615..3a6612121 100644 --- a/pkg/core/native/native_nep17.go +++ b/pkg/core/native/native_nep17.go @@ -343,12 +343,12 @@ func toUint160(s stackitem.Item) util.Uint160 { func toUint32(s stackitem.Item) uint32 { bigInt := toBigInt(s) - if !bigInt.IsInt64() { - panic("bigint is not an int64") + if !bigInt.IsUint64() { + panic("bigint is not an uint64") } - int64Value := bigInt.Int64() - if int64Value < 0 || int64Value > math.MaxUint32 { + uint64Value := bigInt.Uint64() + if uint64Value > math.MaxUint32 { panic("bigint does not fit into uint32") } - return uint32(int64Value) + return uint32(uint64Value) } diff --git a/pkg/core/native/neo_types.go b/pkg/core/native/neo_types.go index 15dcf1e97..21e5dfb3b 100644 --- a/pkg/core/native/neo_types.go +++ b/pkg/core/native/neo_types.go @@ -6,7 +6,6 @@ import ( "math/big" "github.com/nspcc-dev/neo-go/pkg/crypto/keys" - "github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" ) @@ -84,21 +83,18 @@ func (k *keysWithVotes) fromStackItem(item stackitem.Item) error { // Bytes serializes keys with votes slice. func (k keysWithVotes) Bytes() []byte { - var it = k.toStackItem() - var w = io.NewBufBinWriter() - stackitem.EncodeBinary(it, w.BinWriter) - if w.Err != nil { - panic(w.Err) + buf, err := stackitem.Serialize(k.toStackItem()) + if err != nil { + panic(err) } - return w.Bytes() + return buf } // DecodeBytes deserializes keys and votes slice. func (k *keysWithVotes) DecodeBytes(data []byte) error { - var r = io.NewBinReaderFromBuf(data) - var it = stackitem.DecodeBinary(r) - if r.Err != nil { - return r.Err + it, err := stackitem.Deserialize(data) + if err != nil { + return err } return k.fromStackItem(it) } diff --git a/pkg/core/native/notary.go b/pkg/core/native/notary.go index e29224a0c..36d268a3e 100644 --- a/pkg/core/native/notary.go +++ b/pkg/core/native/notary.go @@ -414,7 +414,7 @@ func (n *Notary) setMaxNotValidBeforeDelta(ic *interop.Context, args []stackitem func (n *Notary) GetDepositFor(dao dao.DAO, acc util.Uint160) *state.Deposit { key := append([]byte{prefixDeposit}, acc.BytesBE()...) deposit := new(state.Deposit) - err := getSerializableFromDAO(n.ID, dao, key, deposit) + err := getConvertibleFromDAO(n.ID, dao, key, deposit) if err == nil { return deposit } @@ -427,7 +427,7 @@ func (n *Notary) GetDepositFor(dao dao.DAO, acc util.Uint160) *state.Deposit { // putDepositFor puts deposit on the balance of the specified account in the storage. func (n *Notary) putDepositFor(dao dao.DAO, deposit *state.Deposit, acc util.Uint160) error { key := append([]byte{prefixDeposit}, acc.BytesBE()...) - return putSerializableToDAO(n.ID, dao, key, deposit) + return putConvertibleToDAO(n.ID, dao, key, deposit) } // removeDepositFor removes deposit from the storage. diff --git a/pkg/core/native/oracle.go b/pkg/core/native/oracle.go index 7c500f6cf..9f7ead2ac 100644 --- a/pkg/core/native/oracle.go +++ b/pkg/core/native/oracle.go @@ -170,7 +170,7 @@ func (o *Oracle) PostPersist(ic *interop.Context) error { } reqKey := makeRequestKey(resp.ID) req := new(state.OracleRequest) - if err := o.getSerializableFromDAO(ic.DAO, reqKey, req); err != nil { + if err := o.getConvertibleFromDAO(ic.DAO, reqKey, req); err != nil { continue } if err := ic.DAO.DeleteStorageItem(o.ID, reqKey); err != nil { @@ -182,7 +182,7 @@ func (o *Oracle) PostPersist(ic *interop.Context) error { idKey := makeIDListKey(req.URL) idList := new(IDList) - if err := o.getSerializableFromDAO(ic.DAO, idKey, idList); err != nil { + if err := o.getConvertibleFromDAO(ic.DAO, idKey, idList); err != nil { return err } if !idList.Remove(resp.ID) { @@ -193,7 +193,7 @@ func (o *Oracle) PostPersist(ic *interop.Context) error { if len(*idList) == 0 { err = ic.DAO.DeleteStorageItem(o.ID, idKey) } else { - err = ic.DAO.PutStorageItem(o.ID, idKey, idList.Bytes()) + err = putConvertibleToDAO(o.ID, ic.DAO, idKey, idList) } if err != nil { return err @@ -277,11 +277,13 @@ func (o *Oracle) FinishInternal(ic *interop.Context) error { }), }) - r := io.NewBinReaderFromBuf(req.UserData) - userData := stackitem.DecodeBinary(r) + userData, err := stackitem.Deserialize(req.UserData) + if err != nil { + return err + } args := []stackitem.Item{ stackitem.Make(req.URL), - stackitem.Make(userData), + userData, stackitem.Make(resp.Code), stackitem.Make(resp.Result), } @@ -327,10 +329,10 @@ func (o *Oracle) request(ic *interop.Context, args []stackitem.Item) stackitem.I // RequestInternal processes oracle request. func (o *Oracle) RequestInternal(ic *interop.Context, url string, filter *string, cb string, userData stackitem.Item, gas *big.Int) error { - if len(url) > maxURLLength || (filter != nil && len(*filter) > maxFilterLength) || len(cb) > maxCallbackLength { + if len(url) > maxURLLength || (filter != nil && len(*filter) > maxFilterLength) || len(cb) > maxCallbackLength || !gas.IsInt64() { return ErrBigArgument } - if gas.Uint64() < MinimumResponseGas { + if gas.Int64() < MinimumResponseGas { return ErrLowResponseGas } if strings.HasPrefix(cb, "_") { @@ -357,12 +359,10 @@ func (o *Oracle) RequestInternal(ic *interop.Context, url string, filter *string return err } - w := io.NewBufBinWriter() - stackitem.EncodeBinary(userData, w.BinWriter) - if w.Err != nil { - return w.Err + data, err := stackitem.Serialize(userData) + if err != nil { + return err } - data := w.Bytes() if len(data) > maxUserDataLength { return ErrBigArgument } @@ -398,7 +398,7 @@ func (o *Oracle) RequestInternal(ic *interop.Context, url string, filter *string // PutRequestInternal puts oracle request with the specified id to d. func (o *Oracle) PutRequestInternal(id uint64, req *state.OracleRequest, d dao.DAO) error { reqKey := makeRequestKey(id) - if err := d.PutStorageItem(o.ID, reqKey, req.Bytes()); err != nil { + if err := putConvertibleToDAO(o.ID, d, reqKey, req); err != nil { return err } o.newRequests[id] = req @@ -406,14 +406,14 @@ func (o *Oracle) PutRequestInternal(id uint64, req *state.OracleRequest, d dao.D // Add request ID to the id list. lst := new(IDList) key := makeIDListKey(req.URL) - if err := o.getSerializableFromDAO(d, key, lst); err != nil && !errors.Is(err, storage.ErrKeyNotFound) { + if err := o.getConvertibleFromDAO(d, key, lst); err != nil && !errors.Is(err, storage.ErrKeyNotFound) { return err } if len(*lst) >= maxRequestsCount { return fmt.Errorf("there are too many pending requests for %s url", req.URL) } *lst = append(*lst, id) - return d.PutStorageItem(o.ID, key, lst.Bytes()) + return putConvertibleToDAO(o.ID, d, key, lst) } // GetScriptHash returns script hash or oracle nodes. @@ -431,14 +431,14 @@ func (o *Oracle) GetOracleNodes(d dao.DAO) (keys.PublicKeys, error) { func (o *Oracle) GetRequestInternal(d dao.DAO, id uint64) (*state.OracleRequest, error) { key := makeRequestKey(id) req := new(state.OracleRequest) - return req, o.getSerializableFromDAO(d, key, req) + return req, o.getConvertibleFromDAO(d, key, req) } // GetIDListInternal returns request by ID and key under which it is stored. func (o *Oracle) GetIDListInternal(d dao.DAO, url string) (*IDList, error) { key := makeIDListKey(url) idList := new(IDList) - return idList, o.getSerializableFromDAO(d, key, idList) + return idList, o.getConvertibleFromDAO(d, key, idList) } func (o *Oracle) verify(ic *interop.Context, _ []stackitem.Item) stackitem.Item { @@ -493,11 +493,10 @@ func (o *Oracle) getRequests(d dao.DAO) (map[uint64]*state.OracleRequest, error) if len(k) != 8 { return nil, errors.New("invalid request ID") } - r := io.NewBinReaderFromBuf(si) req := new(state.OracleRequest) - req.DecodeBinary(r) - if r.Err != nil { - return nil, r.Err + err = stackitem.DeserializeConvertible(si, req) + if err != nil { + return nil, err } id := binary.BigEndian.Uint64([]byte(k)) reqs[id] = req @@ -516,8 +515,8 @@ func makeIDListKey(url string) []byte { return append(prefixIDList, hash.Hash160([]byte(url)).BytesBE()...) } -func (o *Oracle) getSerializableFromDAO(d dao.DAO, key []byte, item io.Serializable) error { - return getSerializableFromDAO(o.ID, d, key, item) +func (o *Oracle) getConvertibleFromDAO(d dao.DAO, key []byte, item stackitem.Convertible) error { + return getConvertibleFromDAO(o.ID, d, key, item) } // updateCache updates cached Oracle values if they've been changed. diff --git a/pkg/core/native/oracle_types.go b/pkg/core/native/oracle_types.go index 985897e72..b678f67db 100644 --- a/pkg/core/native/oracle_types.go +++ b/pkg/core/native/oracle_types.go @@ -6,7 +6,6 @@ import ( "math/big" "github.com/nspcc-dev/neo-go/pkg/crypto/keys" - "github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" ) @@ -16,36 +15,17 @@ type IDList []uint64 // NodeList represents list or oracle nodes. type NodeList keys.PublicKeys -// Bytes return l serizalized to a byte-slice. -func (l IDList) Bytes() []byte { - w := io.NewBufBinWriter() - l.EncodeBinary(w.BinWriter) - return w.Bytes() -} - -// EncodeBinary implements io.Serializable. -func (l IDList) EncodeBinary(w *io.BinWriter) { - stackitem.EncodeBinary(l.toStackItem(), w) -} - -// DecodeBinary implements io.Serializable. -func (l *IDList) DecodeBinary(r *io.BinReader) { - item := stackitem.DecodeBinary(r) - if r.Err != nil || item == nil { - return - } - r.Err = l.fromStackItem(item) -} - -func (l IDList) toStackItem() stackitem.Item { +// ToStackItem implements stackitem.Convertible. It never returns an error. +func (l IDList) ToStackItem() (stackitem.Item, error) { arr := make([]stackitem.Item, len(l)) for i := range l { arr[i] = stackitem.NewBigInteger(new(big.Int).SetUint64(l[i])) } - return stackitem.NewArray(arr) + return stackitem.NewArray(arr), nil } -func (l *IDList) fromStackItem(it stackitem.Item) error { +// FromStackItem implements stackitem.Convertible. +func (l *IDList) FromStackItem(it stackitem.Item) error { arr, ok := it.Value().([]stackitem.Item) if !ok { return errors.New("not an array") @@ -75,36 +55,17 @@ func (l *IDList) Remove(id uint64) bool { return false } -// Bytes return l serizalized to a byte-slice. -func (l NodeList) Bytes() []byte { - w := io.NewBufBinWriter() - l.EncodeBinary(w.BinWriter) - return w.Bytes() -} - -// EncodeBinary implements io.Serializable. -func (l NodeList) EncodeBinary(w *io.BinWriter) { - stackitem.EncodeBinary(l.toStackItem(), w) -} - -// DecodeBinary implements io.Serializable. -func (l *NodeList) DecodeBinary(r *io.BinReader) { - item := stackitem.DecodeBinary(r) - if r.Err != nil || item == nil { - return - } - r.Err = l.fromStackItem(item) -} - -func (l NodeList) toStackItem() stackitem.Item { +// ToStackItem implements stackitem.Convertible. It never returns an error. +func (l NodeList) ToStackItem() (stackitem.Item, error) { arr := make([]stackitem.Item, len(l)) for i := range l { arr[i] = stackitem.NewByteArray(l[i].Bytes()) } - return stackitem.NewArray(arr) + return stackitem.NewArray(arr), nil } -func (l *NodeList) fromStackItem(it stackitem.Item) error { +// FromStackItem implements stackitem.Convertible. +func (l *NodeList) FromStackItem(it stackitem.Item) error { arr, ok := it.Value().([]stackitem.Item) if !ok { return errors.New("not an array") diff --git a/pkg/core/native/oracle_types_test.go b/pkg/core/native/oracle_types_test.go index c58323dd9..cba0e20b1 100644 --- a/pkg/core/native/oracle_types_test.go +++ b/pkg/core/native/oracle_types_test.go @@ -5,32 +5,26 @@ import ( "github.com/nspcc-dev/neo-go/internal/testserdes" "github.com/nspcc-dev/neo-go/pkg/crypto/keys" - "github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" "github.com/stretchr/testify/require" ) -func getInvalidTestFunc(actual io.Serializable, value interface{}) func(t *testing.T) { +func getInvalidTestFunc(actual stackitem.Convertible, value interface{}) func(t *testing.T) { return func(t *testing.T) { - w := io.NewBufBinWriter() it := stackitem.Make(value) - stackitem.EncodeBinary(it, w.BinWriter) - require.NoError(t, w.Err) - require.Error(t, testserdes.DecodeBinary(w.Bytes(), actual)) + require.Error(t, actual.FromStackItem(it)) } } -func TestIDList_EncodeBinary(t *testing.T) { +func TestIDListToFromSI(t *testing.T) { t.Run("Valid", func(t *testing.T) { l := &IDList{1, 4, 5} - testserdes.EncodeDecodeBinary(t, l, new(IDList)) + var l2 = new(IDList) + testserdes.ToFromStackItem(t, l, l2) }) t.Run("Invalid", func(t *testing.T) { t.Run("NotArray", getInvalidTestFunc(new(IDList), []byte{})) t.Run("InvalidElement", getInvalidTestFunc(new(IDList), []stackitem.Item{stackitem.Null{}})) - t.Run("NotStackItem", func(t *testing.T) { - require.Error(t, testserdes.DecodeBinary([]byte{0x77}, new(IDList))) - }) }) } @@ -50,22 +44,20 @@ func TestIDList_Remove(t *testing.T) { require.Equal(t, IDList{1}, l) } -func TestNodeList_EncodeBinary(t *testing.T) { +func TestNodeListToFromSI(t *testing.T) { priv, err := keys.NewPrivateKey() require.NoError(t, err) pub := priv.PublicKey() t.Run("Valid", func(t *testing.T) { l := &NodeList{pub} - testserdes.EncodeDecodeBinary(t, l, new(NodeList)) + var l2 = new(NodeList) + testserdes.ToFromStackItem(t, l, l2) }) t.Run("Invalid", func(t *testing.T) { t.Run("NotArray", getInvalidTestFunc(new(NodeList), []byte{})) t.Run("InvalidElement", getInvalidTestFunc(new(NodeList), []stackitem.Item{stackitem.Null{}})) t.Run("InvalidKey", getInvalidTestFunc(new(NodeList), []stackitem.Item{stackitem.NewByteArray([]byte{0x9})})) - t.Run("NotStackItem", func(t *testing.T) { - require.Error(t, testserdes.DecodeBinary([]byte{0x77}, new(NodeList))) - }) }) } diff --git a/pkg/core/native/std_test.go b/pkg/core/native/std_test.go index 2558b5f2d..2817850e0 100644 --- a/pkg/core/native/std_test.go +++ b/pkg/core/native/std_test.go @@ -11,7 +11,6 @@ import ( "github.com/mr-tron/base58" "github.com/nspcc-dev/neo-go/pkg/core/interop" base58neogo "github.com/nspcc-dev/neo-go/pkg/encoding/base58" - "github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/vm" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" "github.com/stretchr/testify/assert" @@ -275,11 +274,8 @@ func TestStdLibSerialize(t *testing.T) { actualSerialized = s.serialize(ic, []stackitem.Item{stackitem.Make(42)}) }) - w := io.NewBufBinWriter() - stackitem.EncodeBinary(stackitem.Make(42), w.BinWriter) - require.NoError(t, w.Err) - - encoded := w.Bytes() + encoded, err := stackitem.Serialize(stackitem.Make(42)) + require.NoError(t, err) require.Equal(t, stackitem.Make(encoded), actualSerialized) require.NotPanics(t, func() { diff --git a/pkg/core/native/util.go b/pkg/core/native/util.go index c882762cb..afb20728c 100644 --- a/pkg/core/native/util.go +++ b/pkg/core/native/util.go @@ -8,30 +8,26 @@ import ( "github.com/nspcc-dev/neo-go/pkg/core/dao" "github.com/nspcc-dev/neo-go/pkg/core/storage" "github.com/nspcc-dev/neo-go/pkg/encoding/bigint" - "github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" ) var intOne = big.NewInt(1) -func getSerializableFromDAO(id int32, d dao.DAO, key []byte, item io.Serializable) error { +func getConvertibleFromDAO(id int32, d dao.DAO, key []byte, conv stackitem.Convertible) error { si := d.GetStorageItem(id, key) if si == nil { return storage.ErrKeyNotFound } - r := io.NewBinReaderFromBuf(si) - item.DecodeBinary(r) - return r.Err + return stackitem.DeserializeConvertible(si, conv) } -func putSerializableToDAO(id int32, d dao.DAO, key []byte, item io.Serializable) error { - w := io.NewBufBinWriter() - item.EncodeBinary(w.BinWriter) - if w.Err != nil { - return w.Err +func putConvertibleToDAO(id int32, d dao.DAO, key []byte, conv stackitem.Convertible) error { + data, err := stackitem.SerializeConvertible(conv) + if err != nil { + return err } - return d.PutStorageItem(id, key, w.Bytes()) + return d.PutStorageItem(id, key, data) } func setIntWithKey(id int32, dao dao.DAO, key []byte, value int64) error { diff --git a/pkg/core/state/contract.go b/pkg/core/state/contract.go index e6ee4e873..9d7cf8514 100644 --- a/pkg/core/state/contract.go +++ b/pkg/core/state/contract.go @@ -35,25 +35,6 @@ type NativeContract struct { UpdateHistory []uint32 `json:"updatehistory"` } -// DecodeBinary implements Serializable interface. -func (c *Contract) DecodeBinary(r *io.BinReader) { - si := stackitem.DecodeBinary(r) - if r.Err != nil { - return - } - r.Err = c.FromStackItem(si) -} - -// EncodeBinary implements Serializable interface. -func (c *Contract) EncodeBinary(w *io.BinWriter) { - si, err := c.ToStackItem() - if err != nil { - w.Err = err - return - } - stackitem.EncodeBinary(si, w) -} - // ToStackItem converts state.Contract to stackitem.Item. func (c *Contract) ToStackItem() (stackitem.Item, error) { rawNef, err := c.NEF.Bytes() @@ -92,10 +73,10 @@ func (c *Contract) FromStackItem(item stackitem.Item) error { if !ok { return errors.New("UpdateCounter is not an integer") } - if !bi.IsInt64() || bi.Int64() > math.MaxUint16 || bi.Int64() < 0 { + if !bi.IsUint64() || bi.Uint64() > math.MaxUint16 { return errors.New("UpdateCounter not in uint16 range") } - c.UpdateCounter = uint16(bi.Int64()) + c.UpdateCounter = uint16(bi.Uint64()) bytes, err := arr[2].TryBytes() if err != nil { return err diff --git a/pkg/core/state/contract_test.go b/pkg/core/state/contract_test.go index d2c7987ec..3f872974a 100644 --- a/pkg/core/state/contract_test.go +++ b/pkg/core/state/contract_test.go @@ -14,7 +14,7 @@ import ( "github.com/stretchr/testify/require" ) -func TestEncodeDecodeContractState(t *testing.T) { +func TestContractStateToFromSI(t *testing.T) { script := []byte("testscript") h := hash.Hash160(script) @@ -52,9 +52,9 @@ func TestEncodeDecodeContractState(t *testing.T) { } contract.NEF.Checksum = contract.NEF.CalculateChecksum() - t.Run("Serializable", func(t *testing.T) { + t.Run("Convertible", func(t *testing.T) { contractDecoded := new(Contract) - testserdes.EncodeDecodeBinary(t, contract, contractDecoded) + testserdes.ToFromStackItem(t, contract, contractDecoded) }) t.Run("JSON", func(t *testing.T) { contractDecoded := new(Contract) diff --git a/pkg/core/state/deposit.go b/pkg/core/state/deposit.go index 8e45e9595..59c97fe54 100644 --- a/pkg/core/state/deposit.go +++ b/pkg/core/state/deposit.go @@ -1,10 +1,12 @@ package state import ( + "errors" + "fmt" + "math" "math/big" - "github.com/nspcc-dev/neo-go/pkg/encoding/bigint" - "github.com/nspcc-dev/neo-go/pkg/io" + "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" ) // Deposit represents GAS deposit from Notary contract. @@ -13,14 +15,37 @@ type Deposit struct { Till uint32 } -// EncodeBinary implements io.Serializable interface. -func (d *Deposit) EncodeBinary(w *io.BinWriter) { - w.WriteVarBytes(bigint.ToBytes(d.Amount)) - w.WriteU32LE(d.Till) +// ToStackItem implements stackitem.Convertible interface. It never returns an +// error. +func (d *Deposit) ToStackItem() (stackitem.Item, error) { + return stackitem.NewStruct([]stackitem.Item{ + stackitem.NewBigInteger(d.Amount), + stackitem.Make(d.Till), + }), nil } -// DecodeBinary implements io.Serializable interface. -func (d *Deposit) DecodeBinary(r *io.BinReader) { - d.Amount = bigint.FromBytes(r.ReadVarBytes()) - d.Till = r.ReadU32LE() +// FromStackItem implements stackitem.Convertible interface. +func (d *Deposit) FromStackItem(it stackitem.Item) error { + items, ok := it.Value().([]stackitem.Item) + if !ok { + return errors.New("not a struct") + } + if len(items) != 2 { + return errors.New("wrong number of elements") + } + amount, err := items[0].TryInteger() + if err != nil { + return fmt.Errorf("invalid amount: %w", err) + } + till, err := items[1].TryInteger() + if err != nil { + return fmt.Errorf("invalid till: %w", err) + } + tiu64 := till.Uint64() + if !till.IsUint64() || tiu64 > math.MaxUint32 { + return errors.New("wrong till value") + } + d.Amount = amount + d.Till = uint32(tiu64) + return nil } diff --git a/pkg/core/state/deposit_test.go b/pkg/core/state/deposit_test.go new file mode 100644 index 000000000..0b19052e6 --- /dev/null +++ b/pkg/core/state/deposit_test.go @@ -0,0 +1,54 @@ +package state + +import ( + "math/big" + "testing" + + "github.com/nspcc-dev/neo-go/internal/testserdes" + "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" + "github.com/stretchr/testify/require" +) + +func TestEncodeDecodeDeposit(t *testing.T) { + d := &Deposit{Amount: big.NewInt(100500), Till: 888} + depo := new(Deposit) + testserdes.ToFromStackItem(t, d, depo) +} + +func TestDepositFromStackItem(t *testing.T) { + var d Deposit + + item := stackitem.Make(42) + require.Error(t, d.FromStackItem(item)) + + item = stackitem.NewStruct(nil) + require.Error(t, d.FromStackItem(item)) + + item = stackitem.NewStruct([]stackitem.Item{ + stackitem.NewStruct(nil), + stackitem.NewStruct(nil), + }) + require.Error(t, d.FromStackItem(item)) + + item = stackitem.NewStruct([]stackitem.Item{ + stackitem.Make(777), + stackitem.NewStruct(nil), + }) + require.Error(t, d.FromStackItem(item)) + + item = stackitem.NewStruct([]stackitem.Item{ + stackitem.Make(777), + stackitem.Make(-1), + }) + require.Error(t, d.FromStackItem(item)) + item = stackitem.NewStruct([]stackitem.Item{ + stackitem.Make(777), + stackitem.Make("somenonu64value"), + }) + require.Error(t, d.FromStackItem(item)) + item = stackitem.NewStruct([]stackitem.Item{ + stackitem.Make(777), + stackitem.Make(888), + }) + require.NoError(t, d.FromStackItem(item)) +} diff --git a/pkg/core/state/native_state.go b/pkg/core/state/native_state.go index 75baee89d..a4f652e3f 100644 --- a/pkg/core/state/native_state.go +++ b/pkg/core/state/native_state.go @@ -7,7 +7,6 @@ import ( "math/big" "github.com/nspcc-dev/neo-go/pkg/crypto/keys" - "github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" ) @@ -26,99 +25,81 @@ type NEOBalanceState struct { // NEP17BalanceStateFromBytes converts serialized NEP17BalanceState to structure. func NEP17BalanceStateFromBytes(b []byte) (*NEP17BalanceState, error) { balance := new(NEP17BalanceState) - if len(b) == 0 { - return balance, nil - } - r := io.NewBinReaderFromBuf(b) - balance.DecodeBinary(r) - if r.Err != nil { - return nil, r.Err + err := balanceFromBytes(b, balance) + if err != nil { + return nil, err } return balance, nil } // Bytes returns serialized NEP17BalanceState. func (s *NEP17BalanceState) Bytes() []byte { - w := io.NewBufBinWriter() - s.EncodeBinary(w.BinWriter) - if w.Err != nil { - panic(w.Err) + return balanceToBytes(s) +} + +func balanceFromBytes(b []byte, item stackitem.Convertible) error { + if len(b) == 0 { + return nil } - return w.Bytes() + return stackitem.DeserializeConvertible(b, item) } -func (s *NEP17BalanceState) toStackItem() stackitem.Item { - return stackitem.NewStruct([]stackitem.Item{stackitem.NewBigInteger(&s.Balance)}) -} - -func (s *NEP17BalanceState) fromStackItem(item stackitem.Item) { - s.Balance = *item.(*stackitem.Struct).Value().([]stackitem.Item)[0].Value().(*big.Int) -} - -// EncodeBinary implements io.Serializable interface. -func (s *NEP17BalanceState) EncodeBinary(w *io.BinWriter) { - si := s.toStackItem() - stackitem.EncodeBinary(si, w) -} - -// DecodeBinary implements io.Serializable interface. -func (s *NEP17BalanceState) DecodeBinary(r *io.BinReader) { - si := stackitem.DecodeBinary(r) - if r.Err != nil { - return +func balanceToBytes(item stackitem.Convertible) []byte { + data, err := stackitem.SerializeConvertible(item) + if err != nil { + panic(err) } - s.fromStackItem(si) + return data +} + +// ToStackItem implements stackitem.Convertible. It never returns an error. +func (s *NEP17BalanceState) ToStackItem() (stackitem.Item, error) { + return stackitem.NewStruct([]stackitem.Item{stackitem.NewBigInteger(&s.Balance)}), nil +} + +// FromStackItem implements stackitem.Convertible. +func (s *NEP17BalanceState) FromStackItem(item stackitem.Item) error { + items, ok := item.Value().([]stackitem.Item) + if !ok { + return errors.New("not a struct") + } + if len(items) < 1 { + return errors.New("no balance value") + } + balance, err := items[0].TryInteger() + if err != nil { + return fmt.Errorf("invalid balance: %w", err) + } + s.Balance = *balance + return nil } // NEOBalanceStateFromBytes converts serialized NEOBalanceState to structure. func NEOBalanceStateFromBytes(b []byte) (*NEOBalanceState, error) { balance := new(NEOBalanceState) - if len(b) == 0 { - return balance, nil - } - r := io.NewBinReaderFromBuf(b) - balance.DecodeBinary(r) - - if r.Err != nil { - return nil, r.Err + err := balanceFromBytes(b, balance) + if err != nil { + return nil, err } return balance, nil } // Bytes returns serialized NEOBalanceState. func (s *NEOBalanceState) Bytes() []byte { - w := io.NewBufBinWriter() - s.EncodeBinary(w.BinWriter) - if w.Err != nil { - panic(w.Err) - } - return w.Bytes() + return balanceToBytes(s) } -// EncodeBinary implements io.Serializable interface. -func (s *NEOBalanceState) EncodeBinary(w *io.BinWriter) { - si := s.toStackItem() - stackitem.EncodeBinary(si, w) -} - -// DecodeBinary implements io.Serializable interface. -func (s *NEOBalanceState) DecodeBinary(r *io.BinReader) { - si := stackitem.DecodeBinary(r) - if r.Err != nil { - return - } - r.Err = s.FromStackItem(si) -} - -func (s *NEOBalanceState) toStackItem() stackitem.Item { - result := s.NEP17BalanceState.toStackItem().(*stackitem.Struct) +// ToStackItem implements stackitem.Convertible interface. It never returns an error. +func (s *NEOBalanceState) ToStackItem() (stackitem.Item, error) { + resItem, _ := s.NEP17BalanceState.ToStackItem() + result := resItem.(*stackitem.Struct) result.Append(stackitem.NewBigInteger(big.NewInt(int64(s.BalanceHeight)))) if s.VoteTo != nil { result.Append(stackitem.NewByteArray(s.VoteTo.Bytes())) } else { result.Append(stackitem.Null{}) } - return result + return result, nil } // FromStackItem converts stackitem.Item to NEOBalanceState. diff --git a/pkg/core/state/notification_event.go b/pkg/core/state/notification_event.go index 87ea9dbb1..5c7962ca0 100644 --- a/pkg/core/state/notification_event.go +++ b/pkg/core/state/notification_event.go @@ -72,7 +72,7 @@ func (aer *AppExecResult) DecodeBinary(r *io.BinReader) { aer.VMState = vm.State(r.ReadB()) aer.GasConsumed = int64(r.ReadU64LE()) sz := r.ReadVarUint() - if stackitem.MaxArraySize < sz && r.Err == nil { + if stackitem.MaxDeserialized < sz && r.Err == nil { r.Err = errors.New("invalid format") } if r.Err != nil { diff --git a/pkg/core/state/oracle.go b/pkg/core/state/oracle.go index 96efde901..861a5a1db 100644 --- a/pkg/core/state/oracle.go +++ b/pkg/core/state/oracle.go @@ -5,7 +5,6 @@ import ( "math/big" "unicode/utf8" - "github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" ) @@ -21,28 +20,9 @@ type OracleRequest struct { UserData []byte } -// Bytes return o serizalized to a byte-slice. -func (o *OracleRequest) Bytes() []byte { - w := io.NewBufBinWriter() - o.EncodeBinary(w.BinWriter) - return w.Bytes() -} - -// EncodeBinary implements io.Serializable. -func (o *OracleRequest) EncodeBinary(w *io.BinWriter) { - stackitem.EncodeBinary(o.toStackItem(), w) -} - -// DecodeBinary implements io.Serializable. -func (o *OracleRequest) DecodeBinary(r *io.BinReader) { - item := stackitem.DecodeBinary(r) - if r.Err != nil || item == nil { - return - } - r.Err = o.fromStackItem(item) -} - -func (o *OracleRequest) toStackItem() stackitem.Item { +// ToStackItem implements stackitem.Convertible interface. It never returns an +// error. +func (o *OracleRequest) ToStackItem() (stackitem.Item, error) { filter := stackitem.Item(stackitem.Null{}) if o.Filter != nil { filter = stackitem.Make(*o.Filter) @@ -55,10 +35,11 @@ func (o *OracleRequest) toStackItem() stackitem.Item { stackitem.NewByteArray(o.CallbackContract.BytesBE()), stackitem.Make(o.CallbackMethod), stackitem.NewByteArray(o.UserData), - }) + }), nil } -func (o *OracleRequest) fromStackItem(it stackitem.Item) error { +// FromStackItem implements stackitem.Convertible interface. +func (o *OracleRequest) FromStackItem(it stackitem.Item) error { arr, ok := it.Value().([]stackitem.Item) if !ok || len(arr) < 7 { return errors.New("not an array of needed length") diff --git a/pkg/core/state/oracle_test.go b/pkg/core/state/oracle_test.go index bb17ee6bc..07e2ce1f6 100644 --- a/pkg/core/state/oracle_test.go +++ b/pkg/core/state/oracle_test.go @@ -6,12 +6,11 @@ import ( "github.com/nspcc-dev/neo-go/internal/random" "github.com/nspcc-dev/neo-go/internal/testserdes" - "github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" "github.com/stretchr/testify/require" ) -func TestOracleRequest_EncodeBinary(t *testing.T) { +func TestOracleRequestToFromSI(t *testing.T) { t.Run("Valid", func(t *testing.T) { r := &OracleRequest{ OriginalTxID: random.Uint256(), @@ -21,25 +20,19 @@ func TestOracleRequest_EncodeBinary(t *testing.T) { CallbackMethod: "method", UserData: []byte{1, 2, 3}, } - testserdes.EncodeDecodeBinary(t, r, new(OracleRequest)) + testserdes.ToFromStackItem(t, r, new(OracleRequest)) t.Run("WithFilter", func(t *testing.T) { s := "filter" r.Filter = &s - testserdes.EncodeDecodeBinary(t, r, new(OracleRequest)) + testserdes.ToFromStackItem(t, r, new(OracleRequest)) }) }) t.Run("Invalid", func(t *testing.T) { - w := io.NewBufBinWriter() + var res = new(OracleRequest) t.Run("NotArray", func(t *testing.T) { - w.Reset() it := stackitem.NewByteArray([]byte{}) - stackitem.EncodeBinary(it, w.BinWriter) - require.Error(t, testserdes.DecodeBinary(w.Bytes(), new(OracleRequest))) - }) - t.Run("NotStackItem", func(t *testing.T) { - w.Reset() - require.Error(t, testserdes.DecodeBinary([]byte{0x77}, new(OracleRequest))) + require.Error(t, res.FromStackItem(it)) }) items := []stackitem.Item{ @@ -54,12 +47,10 @@ func TestOracleRequest_EncodeBinary(t *testing.T) { arrItem := stackitem.NewArray(items) runInvalid := func(i int, elem stackitem.Item) func(t *testing.T) { return func(t *testing.T) { - w.Reset() before := items[i] items[i] = elem - stackitem.EncodeBinary(arrItem, w.BinWriter) + require.Error(t, res.FromStackItem(arrItem)) items[i] = before - require.Error(t, testserdes.DecodeBinary(w.Bytes(), new(OracleRequest))) } } t.Run("TxID", func(t *testing.T) { diff --git a/pkg/util/uint160_test.go b/pkg/util/uint160_test.go index 7445242b1..f6bc855a2 100644 --- a/pkg/util/uint160_test.go +++ b/pkg/util/uint160_test.go @@ -1,21 +1,22 @@ -package util +package util_test import ( "encoding/hex" "testing" "github.com/nspcc-dev/neo-go/internal/testserdes" + "github.com/nspcc-dev/neo-go/pkg/util" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestUint160UnmarshalJSON(t *testing.T) { str := "0263c1de100292813b5e075e585acc1bae963b2d" - expected, err := Uint160DecodeStringLE(str) + expected, err := util.Uint160DecodeStringLE(str) assert.NoError(t, err) // UnmarshalJSON decodes hex-strings - var u1, u2 Uint160 + var u1, u2 util.Uint160 assert.NoError(t, u1.UnmarshalJSON([]byte(`"`+str+`"`))) assert.True(t, expected.Equals(u1)) @@ -27,25 +28,25 @@ func TestUint160UnmarshalJSON(t *testing.T) { func TestUInt160DecodeString(t *testing.T) { hexStr := "2d3b96ae1bcc5a585e075e3b81920210dec16302" - val, err := Uint160DecodeStringBE(hexStr) + val, err := util.Uint160DecodeStringBE(hexStr) assert.NoError(t, err) assert.Equal(t, hexStr, val.String()) - valLE, err := Uint160DecodeStringLE(hexStr) + valLE, err := util.Uint160DecodeStringLE(hexStr) assert.NoError(t, err) assert.Equal(t, val, valLE.Reverse()) - _, err = Uint160DecodeStringBE(hexStr[1:]) + _, err = util.Uint160DecodeStringBE(hexStr[1:]) assert.Error(t, err) - _, err = Uint160DecodeStringLE(hexStr[1:]) + _, err = util.Uint160DecodeStringLE(hexStr[1:]) assert.Error(t, err) hexStr = "zz3b96ae1bcc5a585e075e3b81920210dec16302" - _, err = Uint160DecodeStringBE(hexStr) + _, err = util.Uint160DecodeStringBE(hexStr) assert.Error(t, err) - _, err = Uint160DecodeStringLE(hexStr) + _, err = util.Uint160DecodeStringLE(hexStr) assert.Error(t, err) } @@ -54,18 +55,18 @@ func TestUint160DecodeBytes(t *testing.T) { b, err := hex.DecodeString(hexStr) require.NoError(t, err) - val, err := Uint160DecodeBytesBE(b) + val, err := util.Uint160DecodeBytesBE(b) assert.NoError(t, err) assert.Equal(t, hexStr, val.String()) - valLE, err := Uint160DecodeBytesLE(b) + valLE, err := util.Uint160DecodeBytesLE(b) assert.NoError(t, err) assert.Equal(t, val, valLE.Reverse()) - _, err = Uint160DecodeBytesLE(b[1:]) + _, err = util.Uint160DecodeBytesLE(b[1:]) assert.Error(t, err) - _, err = Uint160DecodeBytesBE(b[1:]) + _, err = util.Uint160DecodeBytesBE(b[1:]) assert.Error(t, err) } @@ -73,10 +74,10 @@ func TestUInt160Equals(t *testing.T) { a := "2d3b96ae1bcc5a585e075e3b81920210dec16302" b := "4d3b96ae1bcc5a585e075e3b81920210dec16302" - ua, err := Uint160DecodeStringBE(a) + ua, err := util.Uint160DecodeStringBE(a) require.NoError(t, err) - ub, err := Uint160DecodeStringBE(b) + ub, err := util.Uint160DecodeStringBE(b) require.NoError(t, err) assert.False(t, ua.Equals(ub), "%s and %s cannot be equal", ua, ub) assert.True(t, ua.Equals(ua), "%s and %s must be equal", ua, ua) @@ -86,11 +87,11 @@ func TestUInt160Less(t *testing.T) { a := "2d3b96ae1bcc5a585e075e3b81920210dec16302" b := "2d3b96ae1bcc5a585e075e3b81920210dec16303" - ua, err := Uint160DecodeStringBE(a) + ua, err := util.Uint160DecodeStringBE(a) assert.Nil(t, err) - ua2, err := Uint160DecodeStringBE(a) + ua2, err := util.Uint160DecodeStringBE(a) assert.Nil(t, err) - ub, err := Uint160DecodeStringBE(b) + ub, err := util.Uint160DecodeStringBE(b) assert.Nil(t, err) assert.Equal(t, true, ua.Less(ub)) assert.Equal(t, false, ua.Less(ua2)) @@ -101,7 +102,7 @@ func TestUInt160String(t *testing.T) { hexStr := "b28427088a3729b2536d10122960394e8be6721f" hexRevStr := "1f72e68b4e39602912106d53b229378a082784b2" - val, err := Uint160DecodeStringBE(hexStr) + val, err := util.Uint160DecodeStringBE(hexStr) assert.Nil(t, err) assert.Equal(t, hexStr, val.String()) @@ -110,7 +111,7 @@ func TestUInt160String(t *testing.T) { func TestUint160_Reverse(t *testing.T) { hexStr := "b28427088a3729b2536d10122960394e8be6721f" - val, err := Uint160DecodeStringBE(hexStr) + val, err := util.Uint160DecodeStringBE(hexStr) require.NoError(t, err) assert.Equal(t, hexStr, val.Reverse().StringLE()) diff --git a/pkg/util/uint256_test.go b/pkg/util/uint256_test.go index e0e3d2291..b1f5d2c3e 100644 --- a/pkg/util/uint256_test.go +++ b/pkg/util/uint256_test.go @@ -1,21 +1,22 @@ -package util +package util_test import ( "encoding/hex" "testing" "github.com/nspcc-dev/neo-go/internal/testserdes" + "github.com/nspcc-dev/neo-go/pkg/util" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestUint256UnmarshalJSON(t *testing.T) { str := "f037308fa0ab18155bccfc08485468c112409ea5064595699e98c545f245f32d" - expected, err := Uint256DecodeStringLE(str) + expected, err := util.Uint256DecodeStringLE(str) require.NoError(t, err) // UnmarshalJSON decodes hex-strings - var u1, u2 Uint256 + var u1, u2 util.Uint256 require.NoError(t, u1.UnmarshalJSON([]byte(`"`+str+`"`))) assert.True(t, expected.Equals(u1)) @@ -28,33 +29,33 @@ func TestUint256UnmarshalJSON(t *testing.T) { func TestUint256DecodeString(t *testing.T) { hexStr := "f037308fa0ab18155bccfc08485468c112409ea5064595699e98c545f245f32d" - val, err := Uint256DecodeStringLE(hexStr) + val, err := util.Uint256DecodeStringLE(hexStr) require.NoError(t, err) assert.Equal(t, hexStr, val.StringLE()) - valBE, err := Uint256DecodeStringBE(hexStr) + valBE, err := util.Uint256DecodeStringBE(hexStr) require.NoError(t, err) assert.Equal(t, val, valBE.Reverse()) bs, err := hex.DecodeString(hexStr) require.NoError(t, err) - val1, err := Uint256DecodeBytesBE(bs) + val1, err := util.Uint256DecodeBytesBE(bs) assert.NoError(t, err) assert.Equal(t, hexStr, val1.String()) assert.Equal(t, val, val1.Reverse()) - _, err = Uint256DecodeStringLE(hexStr[1:]) + _, err = util.Uint256DecodeStringLE(hexStr[1:]) assert.Error(t, err) - _, err = Uint256DecodeStringBE(hexStr[1:]) + _, err = util.Uint256DecodeStringBE(hexStr[1:]) assert.Error(t, err) hexStr = "zzz7308fa0ab18155bccfc08485468c112409ea5064595699e98c545f245f32d" - _, err = Uint256DecodeStringLE(hexStr) + _, err = util.Uint256DecodeStringLE(hexStr) assert.Error(t, err) - _, err = Uint256DecodeStringBE(hexStr) + _, err = util.Uint256DecodeStringBE(hexStr) assert.Error(t, err) } @@ -63,11 +64,11 @@ func TestUint256DecodeBytes(t *testing.T) { b, err := hex.DecodeString(hexStr) require.NoError(t, err) - val, err := Uint256DecodeBytesLE(b) + val, err := util.Uint256DecodeBytesLE(b) require.NoError(t, err) assert.Equal(t, hexStr, val.StringLE()) - _, err = Uint256DecodeBytesBE(b[1:]) + _, err = util.Uint256DecodeBytesBE(b[1:]) assert.Error(t, err) } @@ -75,10 +76,10 @@ func TestUInt256Equals(t *testing.T) { a := "f037308fa0ab18155bccfc08485468c112409ea5064595699e98c545f245f32d" b := "e287c5b29a1b66092be6803c59c765308ac20287e1b4977fd399da5fc8f66ab5" - ua, err := Uint256DecodeStringLE(a) + ua, err := util.Uint256DecodeStringLE(a) require.NoError(t, err) - ub, err := Uint256DecodeStringLE(b) + ub, err := util.Uint256DecodeStringLE(b) require.NoError(t, err) assert.False(t, ua.Equals(ub), "%s and %s cannot be equal", ua, ub) assert.True(t, ua.Equals(ua), "%s and %s must be equal", ua, ua) @@ -86,11 +87,11 @@ func TestUInt256Equals(t *testing.T) { } func TestUint256_Serializable(t *testing.T) { - a := Uint256{ + a := util.Uint256{ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, } - var b Uint256 + var b util.Uint256 testserdes.EncodeDecodeBinary(t, &a, &b) } diff --git a/pkg/vm/contract_checks.go b/pkg/vm/contract_checks.go index 9841ffb5b..40b7d6289 100644 --- a/pkg/vm/contract_checks.go +++ b/pkg/vm/contract_checks.go @@ -12,6 +12,9 @@ import ( "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" ) +// MaxMultisigKeys is the maximum number of used keys for correct multisig contract. +const MaxMultisigKeys = 1024 + var ( verifyInteropID = interopnames.ToID([]byte(interopnames.SystemCryptoCheckSig)) multisigInteropID = interopnames.ToID([]byte(interopnames.SystemCryptoCheckMultisig)) @@ -25,14 +28,14 @@ func getNumOfThingsFromInstr(instr opcode.Opcode, param []byte) (int, bool) { nthings = int(instr-opcode.PUSH1) + 1 case instr <= opcode.PUSHINT256: n := bigint.FromBytes(param) - if !n.IsInt64() || n.Int64() > stackitem.MaxArraySize { + if !n.IsInt64() || n.Sign() < 0 || n.Int64() > MaxMultisigKeys { return 0, false } nthings = int(n.Int64()) default: return 0, false } - if nthings < 1 || nthings > stackitem.MaxArraySize { + if nthings < 1 || nthings > MaxMultisigKeys { return 0, false } return nthings, true @@ -76,7 +79,7 @@ func ParseMultiSigContract(script []byte) (int, [][]byte, bool) { } pubs = append(pubs, param) nkeys++ - if nkeys > stackitem.MaxArraySize { + if nkeys > MaxMultisigKeys { return nsigs, nil, false } } diff --git a/pkg/vm/stackitem/item.go b/pkg/vm/stackitem/item.go index c5d7007fe..3a156b22e 100644 --- a/pkg/vm/stackitem/item.go +++ b/pkg/vm/stackitem/item.go @@ -20,10 +20,12 @@ import ( const ( // MaxBigIntegerSizeBits is the maximum size of BigInt item in bits. MaxBigIntegerSizeBits = 32 * 8 - // MaxArraySize is the maximum array size allowed in the VM. - MaxArraySize = 1024 // MaxSize is the maximum item size allowed in the VM. MaxSize = 1024 * 1024 + // MaxComparableNumOfItems is the maximum number of items that can be compared for structs. + MaxComparableNumOfItems = MaxDeserialized + // MaxClonableNumOfItems is the maximum number of items that can be cloned in structs. + MaxClonableNumOfItems = MaxDeserialized // MaxByteArrayComparableSize is the maximum allowed length of ByteArray for Equals method. // It is set to be the maximum uint16 value. MaxByteArrayComparableSize = math.MaxUint16 @@ -52,6 +54,12 @@ type Item interface { Convert(Type) (Item, error) } +// Convertible is something that can be converted to/from Item. +type Convertible interface { + ToStackItem() (Item, error) + FromStackItem(Item) error +} + var ( // ErrInvalidConversion is returned on attempt to make an incorrect // conversion between item types. @@ -67,6 +75,7 @@ var ( errTooBigInteger = fmt.Errorf("%w: integer", ErrTooBig) errTooBigKey = fmt.Errorf("%w: map key", ErrTooBig) errTooBigSize = fmt.Errorf("%w: size", ErrTooBig) + errTooBigElements = fmt.Errorf("%w: many elements", ErrTooBig) ) // mkInvConversion creates conversion error with additional metadata (from and @@ -260,17 +269,35 @@ func (i *Struct) TryInteger() (*big.Int, error) { // Equals implements Item interface. func (i *Struct) Equals(s Item) bool { - if i == s { - return true - } else if s == nil { + if s == nil { return false } val, ok := s.(*Struct) - if !ok || len(i.value) != len(val.value) { + if !ok { + return false + } + var limit = MaxComparableNumOfItems - 1 // 1 for current element. + return i.equalStruct(val, &limit) +} + +func (i *Struct) equalStruct(s *Struct, limit *int) bool { + if i == s { + return true + } else if len(i.value) != len(s.value) { return false } for j := range i.value { - if !i.value[j].Equals(val.value[j]) { + *limit-- + if *limit == 0 { + panic(errTooBigElements) + } + sa, oka := i.value[j].(*Struct) + sb, okb := s.value[j].(*Struct) + if oka && okb { + if !sa.equalStruct(sb, limit) { + return false + } + } else if !i.value[j].Equals(s.value[j]) { return false } } @@ -298,13 +325,18 @@ func (i *Struct) Convert(typ Type) (Item, error) { // Clone returns a Struct with all Struct fields copied by value. // Array fields are still copied by reference. -func (i *Struct) Clone(limit int) (*Struct, error) { +func (i *Struct) Clone() (*Struct, error) { + var limit = MaxClonableNumOfItems - 1 // For this struct itself. return i.clone(&limit) } func (i *Struct) clone(limit *int) (*Struct, error) { ret := &Struct{make([]Item, len(i.value))} for j := range i.value { + *limit-- + if *limit < 0 { + return nil, ErrTooBig + } switch t := i.value[j].(type) { case *Struct: var err error @@ -313,13 +345,9 @@ func (i *Struct) clone(limit *int) (*Struct, error) { if err != nil { return nil, err } - *limit-- default: ret.value[j] = t } - if *limit < 0 { - return nil, ErrTooBig - } } return ret, nil } diff --git a/pkg/vm/stackitem/item_test.go b/pkg/vm/stackitem/item_test.go index 1244074a0..c114229ce 100644 --- a/pkg/vm/stackitem/item_test.go +++ b/pkg/vm/stackitem/item_test.go @@ -172,6 +172,11 @@ var equalsTestCases = map[string][]struct { item2: NewStruct([]Item{NewBigInteger(big.NewInt(1))}), result: true, }, + { + item1: NewStruct([]Item{NewBigInteger(big.NewInt(1)), NewStruct([]Item{})}), + item2: NewStruct([]Item{NewBigInteger(big.NewInt(1)), NewStruct([]Item{})}), + result: true, + }, }, "bigint": { { @@ -381,6 +386,40 @@ func TestEquals(t *testing.T) { } } +func TestEqualsDeepStructure(t *testing.T) { + const perStruct = 4 + var items = []Item{} + var num int + for i := 0; i < perStruct; i++ { + items = append(items, Make(0)) + num++ + } + var layerUp = func(sa *Struct, num int) (*Struct, int) { + items := []Item{} + for i := 0; i < perStruct; i++ { + clon, err := sa.Clone() + require.NoError(t, err) + items = append(items, clon) + } + num *= perStruct + num++ + return NewStruct(items), num + } + var sa = NewStruct(items) + for i := 0; i < 4; i++ { + sa, num = layerUp(sa, num) + } + require.Less(t, num, MaxComparableNumOfItems) + sb, err := sa.Clone() + require.NoError(t, err) + require.True(t, sa.Equals(sb)) + sa, num = layerUp(sa, num) + sb, num = layerUp(sb, num) + + require.Less(t, MaxComparableNumOfItems, num) + require.Panics(t, func() { sa.Equals(sb) }) +} + var marshalJSONTestCases = []struct { input Item result []byte @@ -468,9 +507,12 @@ func TestNewVeryBigInteger(t *testing.T) { func TestStructClone(t *testing.T) { st0 := Struct{} st := Struct{value: []Item{&st0}} - _, err := st.Clone(1) - require.NoError(t, err) - _, err = st.Clone(0) + for i := 0; i < MaxClonableNumOfItems-1; i++ { + nst, err := st.Clone() + require.NoError(t, err) + st = Struct{value: []Item{nst}} + } + _, err := st.Clone() require.Error(t, err) } diff --git a/pkg/vm/stackitem/serialization.go b/pkg/vm/stackitem/serialization.go index 4fb5c2b56..b9a814e46 100644 --- a/pkg/vm/stackitem/serialization.go +++ b/pkg/vm/stackitem/serialization.go @@ -9,6 +9,10 @@ import ( "github.com/nspcc-dev/neo-go/pkg/io" ) +// MaxDeserialized is the maximum number one deserialized item can contain +// (including itself). +const MaxDeserialized = 2048 + // ErrRecursive is returned on attempts to serialize some recursive stack item // (like array including an item with reference to the same array). var ErrRecursive = errors.New("recursive item") @@ -25,6 +29,13 @@ type serContext struct { seen map[Item]sliceNoPointer } +// deserContext is an internal deserialization context. +type deserContext struct { + *io.BinReader + allowInvalid bool + limit int +} + // Serialize encodes given Item into the byte slice. func Serialize(item Item) ([]byte, error) { sc := serContext{ @@ -179,21 +190,36 @@ func Deserialize(data []byte) (Item, error) { // as a function because Item itself is an interface. Caveat: always check // reader's error value before using the returned Item. func DecodeBinary(r *io.BinReader) Item { - return decodeBinary(r, false) + dc := deserContext{ + BinReader: r, + allowInvalid: false, + limit: MaxDeserialized, + } + return dc.decodeBinary() } // DecodeBinaryProtected is similar to DecodeBinary but allows Interop and // Invalid values to be present (making it symmetric to EncodeBinaryProtected). func DecodeBinaryProtected(r *io.BinReader) Item { - return decodeBinary(r, true) + dc := deserContext{ + BinReader: r, + allowInvalid: true, + limit: MaxDeserialized, + } + return dc.decodeBinary() } -func decodeBinary(r *io.BinReader, allowInvalid bool) Item { +func (r *deserContext) decodeBinary() Item { var t = Type(r.ReadB()) if r.Err != nil { return nil } + r.limit-- + if r.limit < 0 { + r.Err = errTooBigElements + return nil + } switch t { case ByteArrayT, BufferT: data := r.ReadVarBytes(MaxSize) @@ -210,9 +236,13 @@ func decodeBinary(r *io.BinReader, allowInvalid bool) Item { return NewBigInteger(num) case ArrayT, StructT: size := int(r.ReadVarUint()) + if size > MaxDeserialized { + r.Err = errTooBigElements + return nil + } arr := make([]Item, size) for i := 0; i < size; i++ { - arr[i] = DecodeBinary(r) + arr[i] = r.decodeBinary() } if t == ArrayT { @@ -221,10 +251,14 @@ func decodeBinary(r *io.BinReader, allowInvalid bool) Item { return NewStruct(arr) case MapT: size := int(r.ReadVarUint()) + if size > MaxDeserialized { + r.Err = errTooBigElements + return nil + } m := NewMap() for i := 0; i < size; i++ { - key := DecodeBinary(r) - value := DecodeBinary(r) + key := r.decodeBinary() + value := r.decodeBinary() if r.Err != nil { break } @@ -234,15 +268,33 @@ func decodeBinary(r *io.BinReader, allowInvalid bool) Item { case AnyT: return Null{} case InteropT: - if allowInvalid { + if r.allowInvalid { return NewInterop(nil) } fallthrough default: - if t == InvalidT && allowInvalid { + if t == InvalidT && r.allowInvalid { return nil } r.Err = fmt.Errorf("%w: %v", ErrInvalidType, t) return nil } } + +// SerializeConvertible serializes Convertible into a slice of bytes. +func SerializeConvertible(conv Convertible) ([]byte, error) { + item, err := conv.ToStackItem() + if err != nil { + return nil, err + } + return Serialize(item) +} + +// DeserializeConvertible deserializes Convertible from a slice of bytes. +func DeserializeConvertible(data []byte, conv Convertible) error { + item, err := Deserialize(data) + if err != nil { + return err + } + return conv.FromStackItem(item) +} diff --git a/pkg/vm/stackitem/serialization_test.go b/pkg/vm/stackitem/serialization_test.go index b5adff651..139fbcbcf 100644 --- a/pkg/vm/stackitem/serialization_test.go +++ b/pkg/vm/stackitem/serialization_test.go @@ -39,6 +39,7 @@ func testSerialize(t *testing.T, expectedErr error, item Item) { func TestSerialize(t *testing.T) { bigByteArray := NewByteArray(make([]byte, MaxSize/2)) smallByteArray := NewByteArray(make([]byte, MaxSize/4)) + zeroByteArray := NewByteArray(make([]byte, 0)) testArray := func(t *testing.T, newItem func([]Item) Item) { arr := newItem([]Item{bigByteArray}) testSerialize(t, nil, arr) @@ -50,6 +51,18 @@ func TestSerialize(t *testing.T) { arr.Value().([]Item)[0] = arr testSerialize(t, ErrRecursive, arr) + + items := make([]Item, 0, MaxDeserialized-1) + for i := 0; i < MaxDeserialized-1; i++ { + items = append(items, zeroByteArray) + } + testSerialize(t, nil, newItem(items)) + + items = append(items, zeroByteArray) + data, err := Serialize(newItem(items)) + require.NoError(t, err) + _, err = Deserialize(data) + require.True(t, errors.Is(err, ErrTooBig), err) } t.Run("array", func(t *testing.T) { testArray(t, func(items []Item) Item { return NewArray(items) }) @@ -126,9 +139,59 @@ func TestSerialize(t *testing.T) { m.Add(Make(0), NewByteArray(make([]byte, MaxSize-MaxKeySize))) m.Add(NewByteArray(make([]byte, MaxKeySize)), Make(1)) testSerialize(t, ErrTooBig, m) + + m = NewMap() + for i := 0; i < MaxDeserialized/2-1; i++ { + m.Add(Make(i), zeroByteArray) + } + testSerialize(t, nil, m) + + for i := 0; i <= MaxDeserialized; i++ { + m.Add(Make(i), zeroByteArray) + } + data, err := Serialize(m) + require.NoError(t, err) + _, err = Deserialize(data) + require.True(t, errors.Is(err, ErrTooBig), err) }) } +func TestEmptyDeserialization(t *testing.T) { + empty := []byte{} + _, err := Deserialize(empty) + require.Error(t, err) +} + +func TestMapDeserializationError(t *testing.T) { + m := NewMap() + m.Add(Make(1), Make(1)) + m.Add(Make(2), nil) // Bad value + m.Add(Make(3), Make(3)) + + w := io.NewBufBinWriter() + EncodeBinaryProtected(m, w.BinWriter) + require.NoError(t, w.Err) + _, err := Deserialize(w.Bytes()) + require.True(t, errors.Is(err, ErrInvalidType), err) +} + +func TestDeserializeTooManyElements(t *testing.T) { + item := Make(0) + for i := 0; i < MaxDeserialized-1; i++ { // 1 for zero inner element. + item = Make([]Item{item}) + } + data, err := Serialize(item) + require.NoError(t, err) + _, err = Deserialize(data) + require.NoError(t, err) + + item = Make([]Item{item}) + data, err = Serialize(item) + require.NoError(t, err) + _, err = Deserialize(data) + require.True(t, errors.Is(err, ErrTooBig), err) +} + func BenchmarkEncodeBinary(b *testing.B) { arr := getBigArray(15) diff --git a/pkg/vm/vm.go b/pkg/vm/vm.go index 7ae6d5cdf..9d996d7f8 100644 --- a/pkg/vm/vm.go +++ b/pkg/vm/vm.go @@ -918,7 +918,7 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro case opcode.POW: exp := v.estack.Pop().BigInt() a := v.estack.Pop().BigInt() - if ei := exp.Int64(); !exp.IsInt64() || ei > math.MaxInt32 || ei < 0 { + if ei := exp.Uint64(); !exp.IsUint64() || ei > maxSHLArg { panic("invalid exponent") } v.estack.PushVal(new(big.Int).Exp(a, exp, nil)) @@ -1027,47 +1027,37 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro case opcode.NEWARRAY0: v.estack.PushVal(stackitem.NewArray([]stackitem.Item{})) - case opcode.NEWARRAY, opcode.NEWARRAYT: - item := v.estack.Pop() - n := item.BigInt().Int64() - if n > stackitem.MaxArraySize { - panic("too long array") + case opcode.NEWARRAY, opcode.NEWARRAYT, opcode.NEWSTRUCT: + n := toInt(v.estack.Pop().BigInt()) + if n < 0 || n > MaxStackSize { + panic("wrong number of elements") } typ := stackitem.AnyT if op == opcode.NEWARRAYT { typ = stackitem.Type(parameter[0]) } items := makeArrayOfType(int(n), typ) - v.estack.PushVal(stackitem.NewArray(items)) + var res stackitem.Item + if op == opcode.NEWSTRUCT { + res = stackitem.NewStruct(items) + } else { + res = stackitem.NewArray(items) + } + v.estack.PushVal(res) case opcode.NEWSTRUCT0: v.estack.PushVal(stackitem.NewStruct([]stackitem.Item{})) - case opcode.NEWSTRUCT: - item := v.estack.Pop() - n := item.BigInt().Int64() - if n > stackitem.MaxArraySize { - panic("too long struct") - } - items := makeArrayOfType(int(n), stackitem.AnyT) - v.estack.PushVal(stackitem.NewStruct(items)) - case opcode.APPEND: itemElem := v.estack.Pop() arrElem := v.estack.Pop() - val := cloneIfStruct(itemElem.value, MaxStackSize-v.refs.size) + val := cloneIfStruct(itemElem.value) switch t := arrElem.value.(type) { case *stackitem.Array: - if t.Len() >= stackitem.MaxArraySize { - panic("too long array") - } t.Append(val) case *stackitem.Struct: - if t.Len() >= stackitem.MaxArraySize { - panic("too long struct") - } t.Append(val) default: panic("APPEND: not of underlying type Array") @@ -1076,8 +1066,8 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro v.refs.Add(val) case opcode.PACK: - n := int(v.estack.Pop().BigInt().Int64()) - if n < 0 || n > v.estack.Len() || n > stackitem.MaxArraySize { + n := toInt(v.estack.Pop().BigInt()) + if n < 0 || n > v.estack.Len() { panic("OPACK: invalid length") } @@ -1148,8 +1138,6 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro case *stackitem.Map: if i := t.Index(key.value); i >= 0 { v.refs.Remove(t.Value().([]stackitem.MapElement)[i].Value) - } else if t.Len() >= stackitem.MaxArraySize { - panic("too big map") } t.Add(key.value, item) v.refs.Add(item) @@ -1370,12 +1358,12 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro src := t.Value().([]stackitem.Item) arr = make([]stackitem.Item, len(src)) for i := range src { - arr[i] = cloneIfStruct(src[i], MaxStackSize-v.refs.size) + arr[i] = cloneIfStruct(src[i]) } case *stackitem.Map: arr = make([]stackitem.Item, 0, t.Len()) for k := range t.Value().([]stackitem.MapElement) { - arr = append(arr, cloneIfStruct(t.Value().([]stackitem.MapElement)[k].Value, MaxStackSize-v.refs.size)) + arr = append(arr, cloneIfStruct(t.Value().([]stackitem.MapElement)[k].Value)) } default: panic("not a Map, Array or Struct") @@ -1741,10 +1729,10 @@ func checkMultisig1(v *VM, curve elliptic.Curve, h []byte, pkeys [][]byte, sig [ return false } -func cloneIfStruct(item stackitem.Item, limit int) stackitem.Item { +func cloneIfStruct(item stackitem.Item) stackitem.Item { switch it := item.(type) { case *stackitem.Struct: - ret, err := it.Clone(limit) + ret, err := it.Clone() if err != nil { panic(err) } diff --git a/pkg/vm/vm_test.go b/pkg/vm/vm_test.go index 484b3cda1..8fb07646f 100644 --- a/pkg/vm/vm_test.go +++ b/pkg/vm/vm_test.go @@ -6,7 +6,6 @@ import ( "encoding/hex" "errors" "fmt" - "math" "math/big" "math/rand" "testing" @@ -719,7 +718,7 @@ func TestPOW(t *testing.T) { t.Run("good, negative, odd", getTestFuncForVM(prog, -8, -2, 3)) t.Run("zero", getTestFuncForVM(prog, 1, 3, 0)) t.Run("negative exponent", getTestFuncForVM(prog, nil, 3, -1)) - t.Run("too big exponent", getTestFuncForVM(prog, nil, 1, math.MaxInt32+1)) + t.Run("too big exponent", getTestFuncForVM(prog, nil, 1, maxSHLArg+1)) } func TestSQRT(t *testing.T) { @@ -1058,7 +1057,7 @@ func TestNEWSTRUCT0(t *testing.T) { func TestNEWARRAYArray(t *testing.T) { prog := makeProgram(opcode.NEWARRAY) t.Run("ByteArray", getTestFuncForVM(prog, stackitem.NewArray([]stackitem.Item{}), []byte{})) - t.Run("BadSize", getTestFuncForVM(prog, nil, stackitem.MaxArraySize+1)) + t.Run("BadSize", getTestFuncForVM(prog, nil, MaxStackSize+1)) t.Run("Integer", getTestFuncForVM(prog, []stackitem.Item{stackitem.Null{}}, 1)) } @@ -1109,7 +1108,7 @@ func TestNEWARRAYT(t *testing.T) { func TestNEWSTRUCT(t *testing.T) { prog := makeProgram(opcode.NEWSTRUCT) t.Run("ByteArray", getTestFuncForVM(prog, stackitem.NewStruct([]stackitem.Item{}), []byte{})) - t.Run("BadSize", getTestFuncForVM(prog, nil, stackitem.MaxArraySize+1)) + t.Run("BadSize", getTestFuncForVM(prog, nil, MaxStackSize+1)) t.Run("Integer", getTestFuncForVM(prog, stackitem.NewStruct([]stackitem.Item{stackitem.Null{}}), 1)) } @@ -1137,15 +1136,20 @@ func TestAPPENDBad(t *testing.T) { func TestAPPENDGoodSizeLimit(t *testing.T) { prog := makeProgram(opcode.NEWARRAY, opcode.DUP, opcode.PUSH0, opcode.APPEND) vm := load(prog) - vm.estack.PushVal(stackitem.MaxArraySize - 1) + vm.estack.PushVal(MaxStackSize - 3) // 1 for array, 1 for copy, 1 for pushed 0. runVM(t, vm) assert.Equal(t, 1, vm.estack.Len()) - assert.Equal(t, stackitem.MaxArraySize, len(vm.estack.Pop().Array())) + assert.Equal(t, MaxStackSize-2, len(vm.estack.Pop().Array())) } func TestAPPENDBadSizeLimit(t *testing.T) { prog := makeProgram(opcode.NEWARRAY, opcode.DUP, opcode.PUSH0, opcode.APPEND) - runWithArgs(t, prog, nil, stackitem.MaxArraySize) + runWithArgs(t, prog, nil, MaxStackSize) +} + +func TestAPPENDRefSizeLimit(t *testing.T) { + prog := makeProgram(opcode.NEWARRAY0, opcode.DUP, opcode.DUP, opcode.APPEND, opcode.JMP, 0xfd) + runWithArgs(t, prog, nil) } func TestPICKITEM(t *testing.T) { @@ -1206,19 +1210,19 @@ func TestSETITEMMap(t *testing.T) { func TestSETITEMBigMapBad(t *testing.T) { prog := makeProgram(opcode.SETITEM) m := stackitem.NewMap() - for i := 0; i < stackitem.MaxArraySize; i++ { + for i := 0; i < MaxStackSize; i++ { m.Add(stackitem.Make(i), stackitem.Make(i)) } - runWithArgs(t, prog, nil, m, stackitem.MaxArraySize, 0) + runWithArgs(t, prog, nil, m, m, MaxStackSize, 0) } // This test checks is SETITEM properly updates reference counter. -// 1. Create 2 arrays of size MaxArraySize - 3. (MaxStackSize = 2 * MaxArraySize) +// 1. Create 2 arrays of size MaxStackSize/2 - 3. // 2. SETITEM each of them to a map. // 3. Replace each of them with a scalar value. func TestSETITEMMapStackLimit(t *testing.T) { - size := stackitem.MaxArraySize - 3 + size := MaxStackSize/2 - 3 m := stackitem.NewMap() m.Add(stackitem.NewBigInteger(big.NewInt(1)), stackitem.NewArray(makeArrayOfType(size, stackitem.BooleanT))) m.Add(stackitem.NewBigInteger(big.NewInt(2)), stackitem.NewArray(makeArrayOfType(size, stackitem.BooleanT))) @@ -1238,7 +1242,7 @@ func TestSETITEMBigMapGood(t *testing.T) { vm := load(prog) m := stackitem.NewMap() - for i := 0; i < stackitem.MaxArraySize; i++ { + for i := 0; i < MaxStackSize-3; i++ { m.Add(stackitem.Make(i), stackitem.Make(i)) } vm.estack.Push(&Element{value: m}) @@ -1724,16 +1728,6 @@ func TestPACK(t *testing.T) { t.Run("Good0Len", getTestFuncForVM(prog, []stackitem.Item{}, 0)) } -func TestPACKBigLen(t *testing.T) { - prog := makeProgram(opcode.PACK) - vm := load(prog) - for i := 0; i <= stackitem.MaxArraySize; i++ { - vm.estack.PushVal(0) - } - vm.estack.PushVal(stackitem.MaxArraySize + 1) - checkVMFailed(t, vm) -} - func TestPACKGood(t *testing.T) { prog := makeProgram(opcode.PACK) elements := []int{55, 34, 42} @@ -1757,7 +1751,7 @@ func TestPACKGood(t *testing.T) { func TestPACK_UNPACK_MaxSize(t *testing.T) { prog := makeProgram(opcode.PACK, opcode.UNPACK) - elements := make([]int, stackitem.MaxArraySize) + elements := make([]int, MaxStackSize-2) vm := load(prog) // canary vm.estack.PushVal(1) @@ -1780,7 +1774,7 @@ func TestPACK_UNPACK_MaxSize(t *testing.T) { func TestPACK_UNPACK_PACK_MaxSize(t *testing.T) { prog := makeProgram(opcode.PACK, opcode.UNPACK, opcode.PACK) - elements := make([]int, stackitem.MaxArraySize) + elements := make([]int, MaxStackSize-2) vm := load(prog) // canary vm.estack.PushVal(1) @@ -2451,6 +2445,14 @@ func TestNestedStructClone(t *testing.T) { } } +func TestNestedStructEquals(t *testing.T) { + h := "560112c501fe0160589d604a12c0db415824f7509d4a102aec4597" // See neo-project/neo-vm#426. + prog, err := hex.DecodeString(h) + require.NoError(t, err) + vm := load(prog) + checkVMFailed(t, vm) +} + func makeProgram(opcodes ...opcode.Opcode) []byte { prog := make([]byte, len(opcodes)+1) // RET for i := 0; i < len(opcodes); i++ {