diff --git a/go.mod b/go.mod index 7e94a0a32..cdf2af233 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/btcsuite/btcd v0.22.0-beta github.com/gorilla/websocket v1.4.2 github.com/hashicorp/golang-lru v0.5.4 + github.com/holiman/uint256 v1.2.0 github.com/mr-tron/base58 v1.2.0 github.com/nspcc-dev/dbft v0.0.0-20210721160347-1b03241391ac github.com/nspcc-dev/go-ordered-json v0.0.0-20210915112629-e1b6cce73d02 diff --git a/go.sum b/go.sum index a38e4b761..7b06b7467 100644 --- a/go.sum +++ b/go.sum @@ -128,6 +128,8 @@ github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/ad github.com/hashicorp/golang-lru v0.5.4 h1:YDjusn29QI/Das2iO9M0BHnIbxPeyuCHsjMW+lJfyTc= github.com/hashicorp/golang-lru v0.5.4/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4= github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= +github.com/holiman/uint256 v1.2.0 h1:gpSYcPLWGv4sG43I2mVLiDZCNDh/EpGjSk8tmtxitHM= +github.com/holiman/uint256 v1.2.0/go.mod h1:y4ga/t+u+Xwd7CpDgZESaRcWy0I7XMlTMA25ApIH5Jw= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= github.com/jessevdk/go-flags v0.0.0-20141203071132-1679536dcc89/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= diff --git a/pkg/core/mempool/mem_pool.go b/pkg/core/mempool/mem_pool.go index 7e1bb98ef..e0790464e 100644 --- a/pkg/core/mempool/mem_pool.go +++ b/pkg/core/mempool/mem_pool.go @@ -3,11 +3,11 @@ package mempool import ( "errors" "fmt" - "math/big" "math/bits" "sort" "sync" + "github.com/holiman/uint256" "github.com/nspcc-dev/neo-go/pkg/core/mempoolevent" "github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/util" @@ -50,8 +50,8 @@ type items []item // utilityBalanceAndFees stores sender's balance and overall fees of // sender's transactions which are currently in mempool. type utilityBalanceAndFees struct { - balance *big.Int - feeSum *big.Int + balance uint256.Int + feeSum uint256.Int } // Pool stores the unconfirms transactions. @@ -164,8 +164,7 @@ func (mp *Pool) tryAddSendersFee(tx *transaction.Transaction, feer Feer, needChe payer := tx.Signers[mp.payerIndex].Account senderFee, ok := mp.fees[payer] if !ok { - senderFee.balance = feer.GetUtilityTokenBalance(payer) - senderFee.feeSum = big.NewInt(0) + _ = senderFee.balance.SetFromBig(feer.GetUtilityTokenBalance(payer)) mp.fees[payer] = senderFee } if needCheck { @@ -173,23 +172,26 @@ func (mp *Pool) tryAddSendersFee(tx *transaction.Transaction, feer Feer, needChe if err != nil { return false } - senderFee.feeSum.Set(newFeeSum) + senderFee.feeSum = newFeeSum } else { - senderFee.feeSum.Add(senderFee.feeSum, big.NewInt(tx.SystemFee+tx.NetworkFee)) + senderFee.feeSum.AddUint64(&senderFee.feeSum, uint64(tx.SystemFee+tx.NetworkFee)) } + mp.fees[payer] = senderFee return true } // checkBalance returns new cumulative fee balance for account or an error in // case sender doesn't have enough GAS to pay for the transaction. -func checkBalance(tx *transaction.Transaction, balance utilityBalanceAndFees) (*big.Int, error) { - txFee := big.NewInt(tx.SystemFee + tx.NetworkFee) - if balance.balance.Cmp(txFee) < 0 { - return nil, ErrInsufficientFunds +func checkBalance(tx *transaction.Transaction, balance utilityBalanceAndFees) (uint256.Int, error) { + var txFee uint256.Int + + txFee.SetUint64(uint64(tx.SystemFee + tx.NetworkFee)) + if balance.balance.Cmp(&txFee) < 0 { + return txFee, ErrInsufficientFunds } - txFee.Add(txFee, balance.feeSum) - if balance.balance.Cmp(txFee) < 0 { - return nil, ErrConflict + txFee.Add(&txFee, &balance.feeSum) + if balance.balance.Cmp(&txFee) < 0 { + return txFee, ErrConflict } return txFee, nil } @@ -323,7 +325,7 @@ func (mp *Pool) removeInternal(hash util.Uint256, feer Feer) { } payer := itm.txn.Signers[mp.payerIndex].Account senderFee := mp.fees[payer] - senderFee.feeSum.Sub(senderFee.feeSum, big.NewInt(tx.SystemFee+tx.NetworkFee)) + senderFee.feeSum.SubUint64(&senderFee.feeSum, uint64(tx.SystemFee+tx.NetworkFee)) mp.fees[payer] = senderFee if feer.P2PSigExtensionsEnabled() { // remove all conflicting hashes from mp.conflicts list @@ -420,7 +422,7 @@ func (mp *Pool) checkPolicy(tx *transaction.Transaction, policyChanged bool) boo // New returns a new Pool struct. func New(capacity int, payerIndex int, enableSubscriptions bool) *Pool { mp := &Pool{ - verifiedMap: make(map[util.Uint256]*transaction.Transaction), + verifiedMap: make(map[util.Uint256]*transaction.Transaction, capacity), verifiedTxes: make([]item, 0, capacity), capacity: capacity, payerIndex: payerIndex, @@ -507,8 +509,7 @@ func (mp *Pool) checkTxConflicts(tx *transaction.Transaction, fee Feer) ([]*tran payer := tx.Signers[mp.payerIndex].Account actualSenderFee, ok := mp.fees[payer] if !ok { - actualSenderFee.balance = fee.GetUtilityTokenBalance(payer) - actualSenderFee.feeSum = big.NewInt(0) + actualSenderFee.balance.SetFromBig(fee.GetUtilityTokenBalance(payer)) } var expectedSenderFee utilityBalanceAndFees @@ -541,13 +542,10 @@ func (mp *Pool) checkTxConflicts(tx *transaction.Transaction, fee Feer) ([]*tran conflictsToBeRemoved = append(conflictsToBeRemoved, existingTx) } // Step 3: take into account sender's conflicting transactions before balance check. - expectedSenderFee = utilityBalanceAndFees{ - balance: new(big.Int).Set(actualSenderFee.balance), - feeSum: new(big.Int).Set(actualSenderFee.feeSum), - } + expectedSenderFee = actualSenderFee for _, conflictingTx := range conflictsToBeRemoved { if conflictingTx.Signers[mp.payerIndex].Account.Equals(payer) { - expectedSenderFee.feeSum.Sub(expectedSenderFee.feeSum, big.NewInt(conflictingTx.SystemFee+conflictingTx.NetworkFee)) + expectedSenderFee.feeSum.SubUint64(&expectedSenderFee.feeSum, uint64(conflictingTx.SystemFee+conflictingTx.NetworkFee)) } } } else { diff --git a/pkg/core/mempool/mem_pool_test.go b/pkg/core/mempool/mem_pool_test.go index a7e73839e..20d571953 100644 --- a/pkg/core/mempool/mem_pool_test.go +++ b/pkg/core/mempool/mem_pool_test.go @@ -7,6 +7,7 @@ import ( "testing" "time" + "github.com/holiman/uint256" "github.com/nspcc-dev/neo-go/internal/random" "github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/network/payload" @@ -268,8 +269,8 @@ func TestMemPoolFees(t *testing.T) { require.NoError(t, mp.Add(tx1, fs)) require.Equal(t, 1, len(mp.fees)) require.Equal(t, utilityBalanceAndFees{ - balance: big.NewInt(fs.balance), - feeSum: big.NewInt(tx1.NetworkFee), + balance: *uint256.NewInt(uint64(fs.balance)), + feeSum: *uint256.NewInt(uint64(tx1.NetworkFee)), }, mp.fees[sender0]) // balance shouldn't change after adding one more transaction @@ -280,8 +281,8 @@ func TestMemPoolFees(t *testing.T) { require.Equal(t, 2, len(mp.verifiedTxes)) require.Equal(t, 1, len(mp.fees)) require.Equal(t, utilityBalanceAndFees{ - balance: big.NewInt(fs.balance), - feeSum: big.NewInt(fs.balance), + balance: *uint256.NewInt(uint64(fs.balance)), + feeSum: *uint256.NewInt(uint64(fs.balance)), }, mp.fees[sender0]) // can't add more transactions as we don't have enough GAS @@ -292,8 +293,8 @@ func TestMemPoolFees(t *testing.T) { require.Error(t, mp.Add(tx3, fs)) require.Equal(t, 1, len(mp.fees)) require.Equal(t, utilityBalanceAndFees{ - balance: big.NewInt(fs.balance), - feeSum: big.NewInt(fs.balance), + balance: *uint256.NewInt(uint64(fs.balance)), + feeSum: *uint256.NewInt(uint64(fs.balance)), }, mp.fees[sender0]) // check whether sender's fee updates correctly @@ -302,8 +303,8 @@ func TestMemPoolFees(t *testing.T) { }, fs) require.Equal(t, 1, len(mp.fees)) require.Equal(t, utilityBalanceAndFees{ - balance: big.NewInt(fs.balance), - feeSum: big.NewInt(tx2.NetworkFee), + balance: *uint256.NewInt(uint64(fs.balance)), + feeSum: *uint256.NewInt(uint64(tx2.NetworkFee)), }, mp.fees[sender0]) // there should be nothing left diff --git a/pkg/core/native/native_neo.go b/pkg/core/native/native_neo.go index 53604bca3..eb5291d87 100644 --- a/pkg/core/native/native_neo.go +++ b/pkg/core/native/native_neo.go @@ -90,6 +90,12 @@ const ( var ( // prefixCommittee is a key used to store committee. prefixCommittee = []byte{14} + + bigCommitteeRewardRatio = big.NewInt(committeeRewardRatio) + bigVoterRewardRatio = big.NewInt(voterRewardRatio) + bigVoterRewardFactor = big.NewInt(voterRewardFactor) + bigEffectiveVoterTurnout = big.NewInt(effectiveVoterTurnout) + big100 = big.NewInt(100) ) // makeValidatorKey creates a key from account script hash. @@ -316,24 +322,26 @@ func (n *NEO) PostPersist(ic *interop.Context) error { pubs := n.GetCommitteeMembers() committeeSize := len(ic.Chain.GetConfig().StandbyCommittee) index := int(ic.Block.Index) % committeeSize - committeeReward := new(big.Int).Mul(gas, big.NewInt(committeeRewardRatio)) - n.GAS.mint(ic, pubs[index].GetScriptHash(), committeeReward.Div(committeeReward, big.NewInt(100)), false) + committeeReward := new(big.Int).Mul(gas, bigCommitteeRewardRatio) + n.GAS.mint(ic, pubs[index].GetScriptHash(), committeeReward.Div(committeeReward, big100), false) if ShouldUpdateCommittee(ic.Block.Index, ic.Chain) { - var voterReward = big.NewInt(voterRewardRatio) + var voterReward = new(big.Int).Set(bigVoterRewardRatio) voterReward.Mul(voterReward, gas) voterReward.Mul(voterReward, big.NewInt(voterRewardFactor*int64(committeeSize))) var validatorsCount = ic.Chain.GetConfig().ValidatorsCount voterReward.Div(voterReward, big.NewInt(int64(committeeSize+validatorsCount))) - voterReward.Div(voterReward, big.NewInt(100)) + voterReward.Div(voterReward, big100) var cs = n.committee.Load().(keysWithVotes) var key = make([]byte, 38) for i := range cs { if cs[i].Votes.Sign() > 0 { - tmp := big.NewInt(1) + var tmp = new(big.Int) if i < validatorsCount { - tmp = big.NewInt(2) + tmp.Set(intTwo) + } else { + tmp.Set(intOne) } tmp.Mul(tmp, voterReward) tmp.Div(tmp, cs[i].Votes) @@ -633,7 +641,7 @@ func (n *NEO) calculateBonus(d dao.DAO, vote *keys.PublicKey, value *big.Int, st var reward = n.getGASPerVote(d, key, start, end) var tmp = new(big.Int).Sub(&reward[1], &reward[0]) tmp.Mul(tmp, value) - tmp.Div(tmp, big.NewInt(voterRewardFactor)) + tmp.Div(tmp, bigVoterRewardFactor) tmp.Add(tmp, r) return tmp, nil } @@ -982,7 +990,7 @@ func (n *NEO) computeCommitteeMembers(bc blockchainer.Blockchainer, d dao.DAO) ( } votersCount := bigint.FromBytes(si) // votersCount / totalSupply must be >= 0.2 - votersCount.Mul(votersCount, big.NewInt(effectiveVoterTurnout)) + votersCount.Mul(votersCount, bigEffectiveVoterTurnout) _, totalSupply := n.getTotalSupply(d) voterTurnout := votersCount.Div(votersCount, totalSupply) diff --git a/pkg/core/native/util.go b/pkg/core/native/util.go index afb20728c..0fadf740d 100644 --- a/pkg/core/native/util.go +++ b/pkg/core/native/util.go @@ -13,6 +13,7 @@ import ( ) var intOne = big.NewInt(1) +var intTwo = big.NewInt(2) func getConvertibleFromDAO(id int32, d dao.DAO, key []byte, conv stackitem.Convertible) error { si := d.GetStorageItem(id, key) diff --git a/pkg/core/storage/memcached_store.go b/pkg/core/storage/memcached_store.go index 065c7303f..d9a6dcbbc 100644 --- a/pkg/core/storage/memcached_store.go +++ b/pkg/core/storage/memcached_store.go @@ -252,8 +252,8 @@ func (s *MemCachedStore) persist(isSync bool) (int, error) { // unprotected while writes are handled by s proper. var tempstore = &MemCachedStore{MemoryStore: MemoryStore{mem: s.mem, del: s.del}, ps: s.ps} s.ps = tempstore - s.mem = make(map[string][]byte) - s.del = make(map[string]bool) + s.mem = make(map[string][]byte, len(s.mem)) + s.del = make(map[string]bool, len(s.del)) if !isSync { s.mut.Unlock() } diff --git a/pkg/crypto/keys/publickey.go b/pkg/crypto/keys/publickey.go index f39a15b1d..ecbeee259 100644 --- a/pkg/crypto/keys/publickey.go +++ b/pkg/crypto/keys/publickey.go @@ -29,6 +29,9 @@ const SignatureLen = 64 // PublicKeys is a list of public keys. type PublicKeys []*PublicKey +var big0 = big.NewInt(0) +var big3 = big.NewInt(3) + func (keys PublicKeys) Len() int { return len(keys) } func (keys PublicKeys) Swap(i, j int) { keys[i], keys[j] = keys[j], keys[i] } func (keys PublicKeys) Less(i, j int) bool { @@ -189,12 +192,12 @@ func decodeCompressedY(x *big.Int, ylsb uint, curve elliptic.Curve) (*big.Int, e var a *big.Int switch curve.(type) { case *btcec.KoblitzCurve: - a = big.NewInt(0) + a = big0 default: - a = big.NewInt(3) + a = big3 } cp := curve.Params() - xCubed := new(big.Int).Exp(x, big.NewInt(3), cp.P) + xCubed := new(big.Int).Exp(x, big3, cp.P) aX := new(big.Int).Mul(x, a) aX.Mod(aX, cp.P) ySquared := new(big.Int).Sub(xCubed, aX) diff --git a/pkg/encoding/bigint/bigint.go b/pkg/encoding/bigint/bigint.go index b2ab067b0..a6b80b70a 100644 --- a/pkg/encoding/bigint/bigint.go +++ b/pkg/encoding/bigint/bigint.go @@ -15,6 +15,8 @@ const ( wordSizeBytes = bits.UintSize / 8 ) +var bigOne = big.NewInt(1) + // FromBytesUnsigned converts data in little-endian format to an unsigned integer. func FromBytesUnsigned(data []byte) *big.Int { bs := slice.CopyReverse(data) @@ -70,7 +72,7 @@ func FromBytes(data []byte) *big.Int { n.SetBits(ws) n.Neg(n) - return n.Sub(n, big.NewInt(1)) + return n.Sub(n, bigOne) } return n.SetBits(ws) @@ -114,7 +116,7 @@ func ToPreallocatedBytes(n *big.Int, data []byte) []byte { if sign == 1 { ws = n.Bits() } else { - n1 := new(big.Int).Add(n, big.NewInt(1)) + n1 := new(big.Int).Add(n, bigOne) if n1.Sign() == 0 { // n == -1 return append(data, 0xFF) } diff --git a/pkg/io/binaryWriter.go b/pkg/io/binaryWriter.go index 5f234c1b4..dc563b149 100644 --- a/pkg/io/binaryWriter.go +++ b/pkg/io/binaryWriter.go @@ -143,7 +143,11 @@ func (w *BinWriter) WriteVarBytes(b []byte) { // WriteString writes a variable length string into the underlying io.Writer. func (w *BinWriter) WriteString(s string) { - w.WriteVarBytes([]byte(s)) + w.WriteVarUint(uint64(len(s))) + if w.Err != nil { + return + } + _, w.Err = io.WriteString(w.w, s) } // Grow tries to increase underlying buffer capacity so that at least n bytes diff --git a/pkg/vm/context.go b/pkg/vm/context.go index 1eaba0ce0..3cbace4b7 100644 --- a/pkg/vm/context.go +++ b/pkg/vm/context.go @@ -32,9 +32,9 @@ type Context struct { // Evaluation stack pointer. estack *Stack - static *Slot - local *Slot - arguments *Slot + static *slot + local slot + arguments slot // Exception context stack. tryStack Stack @@ -277,16 +277,19 @@ func (c *Context) DumpStaticSlot() string { // DumpLocalSlot returns json formatted representation of the given slot. func (c *Context) DumpLocalSlot() string { - return dumpSlot(c.local) + return dumpSlot(&c.local) } // DumpArgumentsSlot returns json formatted representation of the given slot. func (c *Context) DumpArgumentsSlot() string { - return dumpSlot(c.arguments) + return dumpSlot(&c.arguments) } // dumpSlot returns json formatted representation of the given slot. -func dumpSlot(s *Slot) string { +func dumpSlot(s *slot) string { + if s == nil || *s == nil { + return "[]" + } b, _ := json.MarshalIndent(s, "", " ") return string(b) } diff --git a/pkg/vm/json_test.go b/pkg/vm/json_test.go index 5d7f6acf5..bd15eba66 100644 --- a/pkg/vm/json_test.go +++ b/pkg/vm/json_test.go @@ -239,8 +239,8 @@ func compareStacks(t *testing.T, expected []vmUTStackItem, actual *Stack) { compareItemArrays(t, expected, actual.Len(), func(i int) stackitem.Item { return actual.Peek(i).Item() }) } -func compareSlots(t *testing.T, expected []vmUTStackItem, actual *Slot) { - if actual.storage == nil && len(expected) == 0 { +func compareSlots(t *testing.T, expected []vmUTStackItem, actual *slot) { + if (actual == nil || *actual == nil) && len(expected) == 0 { return } require.NotNil(t, actual) diff --git a/pkg/vm/opcodebench_test.go b/pkg/vm/opcodebench_test.go index 1657d4151..465a9943e 100644 --- a/pkg/vm/opcodebench_test.go +++ b/pkg/vm/opcodebench_test.go @@ -59,8 +59,8 @@ func opParamSlotsPushVM(op opcode.Opcode, param []byte, sslot int, slotloc int, v.Context().static.init(sslot) } if slotloc != 0 && slotarg != 0 { - v.Context().local = v.newSlot(slotloc) - v.Context().arguments = v.newSlot(slotarg) + v.Context().local.init(slotloc) + v.Context().arguments.init(slotarg) } for i := range items { item, ok := items[i].(stackitem.Item) diff --git a/pkg/vm/slot.go b/pkg/vm/slot.go index 634891f18..132ee220c 100644 --- a/pkg/vm/slot.go +++ b/pkg/vm/slot.go @@ -6,75 +6,58 @@ import ( "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" ) -// Slot is a fixed-size slice of stack items. -type Slot struct { - storage []stackitem.Item - refs *refCounter -} - -// newSlot returns new slot with the provided reference counter. -func newSlot(refs *refCounter) *Slot { - return &Slot{ - refs: refs, - } -} +// slot is a fixed-size slice of stack items. +type slot []stackitem.Item // init sets static slot size to n. It is intended to be used only by INITSSLOT. -func (s *Slot) init(n int) { - if s.storage != nil { +func (s *slot) init(n int) { + if *s != nil { panic("already initialized") } - s.storage = make([]stackitem.Item, n) -} - -func (v *VM) newSlot(n int) *Slot { - s := newSlot(&v.refs) - s.init(n) - return s + *s = make([]stackitem.Item, n) } // Set sets i-th storage slot. -func (s *Slot) Set(i int, item stackitem.Item) { - if s.storage[i] == item { +func (s slot) Set(i int, item stackitem.Item, refs *refCounter) { + if s[i] == item { return } - old := s.storage[i] - s.storage[i] = item + old := s[i] + s[i] = item if old != nil { - s.refs.Remove(old) + refs.Remove(old) } - s.refs.Add(item) + refs.Add(item) } // Get returns item contained in i-th slot. -func (s *Slot) Get(i int) stackitem.Item { - if item := s.storage[i]; item != nil { +func (s slot) Get(i int) stackitem.Item { + if item := s[i]; item != nil { return item } return stackitem.Null{} } // Clear removes all slot variables from reference counter. -func (s *Slot) Clear() { - for _, item := range s.storage { - s.refs.Remove(item) +func (s slot) Clear(refs *refCounter) { + for _, item := range s { + refs.Remove(item) } } // Size returns slot size. -func (s *Slot) Size() int { - if s.storage == nil { +func (s slot) Size() int { + if s == nil { panic("not initialized") } - return len(s.storage) + return len(s) } // MarshalJSON implements JSON marshalling interface. -func (s *Slot) MarshalJSON() ([]byte, error) { - items := s.storage - arr := make([]json.RawMessage, len(items)) - for i := range items { - data, err := stackitem.ToJSONWithTypes(items[i]) +func (s slot) MarshalJSON() ([]byte, error) { + arr := make([]json.RawMessage, len(s)) + for i := range s { + data, err := stackitem.ToJSONWithTypes(s[i]) if err == nil { arr[i] = data } diff --git a/pkg/vm/slot_test.go b/pkg/vm/slot_test.go index 434464476..212470a9f 100644 --- a/pkg/vm/slot_test.go +++ b/pkg/vm/slot_test.go @@ -9,8 +9,8 @@ import ( ) func TestSlot_Get(t *testing.T) { - s := newSlot(newRefCounter()) - require.NotNil(t, s) + rc := newRefCounter() + var s slot require.Panics(t, func() { s.Size() }) s.init(3) @@ -20,6 +20,6 @@ func TestSlot_Get(t *testing.T) { item := s.Get(2) require.Equal(t, stackitem.Null{}, item) - s.Set(1, stackitem.NewBigInteger(big.NewInt(42))) + s.Set(1, stackitem.NewBigInteger(big.NewInt(42)), rc) require.Equal(t, stackitem.NewBigInteger(big.NewInt(42)), s.Get(1)) } diff --git a/pkg/vm/stackitem/item.go b/pkg/vm/stackitem/item.go index 14ed782e8..8fccd7ca2 100644 --- a/pkg/vm/stackitem/item.go +++ b/pkg/vm/stackitem/item.go @@ -1097,7 +1097,7 @@ func (i *Buffer) Len() int { // Values of Interop items are not deeply copied. // It does preserve duplicates only for non-primitive types. func DeepCopy(item Item) Item { - seen := make(map[Item]Item) + seen := make(map[Item]Item, typicalNumOfItems) return deepCopy(item, seen) } diff --git a/pkg/vm/stackitem/json.go b/pkg/vm/stackitem/json.go index bb01076b0..9fd151d7a 100644 --- a/pkg/vm/stackitem/json.go +++ b/pkg/vm/stackitem/json.go @@ -42,7 +42,7 @@ var ErrTooDeep = errors.New("too deep") // Array, Struct -> array // Map -> map with keys as UTF-8 bytes func ToJSON(item Item) ([]byte, error) { - seen := make(map[Item]sliceNoPointer) + seen := make(map[Item]sliceNoPointer, typicalNumOfItems) return toJSON(nil, seen, item) } @@ -260,7 +260,7 @@ func (d *decoder) decodeMap() (*Map, error) { // ToJSONWithTypes serializes any stackitem to JSON in a lossless way. func ToJSONWithTypes(item Item) ([]byte, error) { - result, err := toJSONWithTypes(item, make(map[Item]bool)) + result, err := toJSONWithTypes(item, make(map[Item]bool, typicalNumOfItems)) if err != nil { return nil, err } diff --git a/pkg/vm/stackitem/serialization.go b/pkg/vm/stackitem/serialization.go index bda4fbc80..50c5a5f2e 100644 --- a/pkg/vm/stackitem/serialization.go +++ b/pkg/vm/stackitem/serialization.go @@ -13,6 +13,12 @@ import ( // (including itself). const MaxDeserialized = 2048 +// typicalNumOfItems is the number of items covering most serializaton needs. +// It's a hint used for map creation, so it's not limiting anything, it's just +// a microoptimization to avoid excessive reallocations. Most of the serialized +// items are structs, so there is at least one of them. +const typicalNumOfItems = 4 + // ErrRecursive is returned on attempts to serialize some recursive stack item // (like array including an item with reference to the same array). var ErrRecursive = errors.New("recursive item") @@ -40,7 +46,7 @@ type deserContext struct { func Serialize(item Item) ([]byte, error) { sc := serContext{ allowInvalid: false, - seen: make(map[Item]sliceNoPointer), + seen: make(map[Item]sliceNoPointer, typicalNumOfItems), } err := sc.serialize(item) if err != nil { @@ -69,7 +75,7 @@ func EncodeBinary(item Item, w *io.BinWriter) { func EncodeBinaryProtected(item Item, w *io.BinWriter) { sc := serContext{ allowInvalid: true, - seen: make(map[Item]sliceNoPointer), + seen: make(map[Item]sliceNoPointer, typicalNumOfItems), } err := sc.serialize(item) if err != nil { @@ -79,6 +85,18 @@ func EncodeBinaryProtected(item Item, w *io.BinWriter) { w.WriteBytes(sc.data) } +func (w *serContext) writeArray(item Item, arr []Item, start int) error { + w.seen[item] = sliceNoPointer{} + w.appendVarUint(uint64(len(arr))) + for i := range arr { + if err := w.serialize(arr[i]); err != nil { + return err + } + } + w.seen[item] = sliceNoPointer{start, len(w.data)} + return nil +} + func (w *serContext) serialize(item Item) error { if v, ok := w.seen[item]; ok { if v.start == v.end { @@ -121,28 +139,20 @@ func (w *serContext) serialize(item Item) error { } else { return fmt.Errorf("%w: Interop", ErrUnserializable) } - case *Array, *Struct: - w.seen[item] = sliceNoPointer{} - - _, isArray := t.(*Array) - if isArray { - w.data = append(w.data, byte(ArrayT)) - } else { - w.data = append(w.data, byte(StructT)) + case *Array: + w.data = append(w.data, byte(ArrayT)) + if err := w.writeArray(item, t.value, start); err != nil { + return err } - - arr := t.Value().([]Item) - w.appendVarUint(uint64(len(arr))) - for i := range arr { - if err := w.serialize(arr[i]); err != nil { - return err - } + case *Struct: + w.data = append(w.data, byte(StructT)) + if err := w.writeArray(item, t.value, start); err != nil { + return err } - w.seen[item] = sliceNoPointer{start, len(w.data)} case *Map: w.seen[item] = sliceNoPointer{} - elems := t.Value().([]MapElement) + elems := t.value w.data = append(w.data, byte(MapT)) w.appendVarUint(uint64(len(elems))) for i := range elems { diff --git a/pkg/vm/vm.go b/pkg/vm/vm.go index 960b0e4f7..33859b91c 100644 --- a/pkg/vm/vm.go +++ b/pkg/vm/vm.go @@ -90,6 +90,8 @@ type VM struct { invTree *InvocationTree } +var bigOne = big.NewInt(1) + // New returns a new VM object ready to load AVM bytecode scripts. func New() *VM { return NewWithTrigger(trigger.Application) @@ -105,6 +107,7 @@ func NewWithTrigger(t trigger.Type) *VM { } initStack(&vm.istack, "invocation", nil) + vm.istack.elems = make([]Element, 0, 8) // Most of invocations use one-two contracts, but they're likely to have internal calls. vm.estack = newStack("evaluation", &vm.refs) return vm } @@ -310,13 +313,15 @@ func (v *VM) LoadNEFMethod(exe *nef.File, caller util.Uint160, hash util.Uint160 // It should be used for calling from native contracts. func (v *VM) loadScriptWithCallingHash(b []byte, exe *nef.File, caller util.Uint160, hash util.Uint160, f callflag.CallFlag, rvcount int, offset int) { + var sl slot + v.checkInvocationStackSize() ctx := NewContextWithParams(b, rvcount, offset) v.estack = newStack("evaluation", &v.refs) ctx.estack = v.estack initStack(&ctx.tryStack, "exception", nil) ctx.callFlag = f - ctx.static = newSlot(&v.refs) + ctx.static = &sl ctx.scriptHash = hash ctx.callingScriptHash = caller ctx.NEF = exe @@ -615,13 +620,13 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro panic("zero argument") } if parameter[0] > 0 { - ctx.local = v.newSlot(int(parameter[0])) + ctx.local.init(int(parameter[0])) } if parameter[1] > 0 { sz := int(parameter[1]) - ctx.arguments = v.newSlot(sz) + ctx.arguments.init(sz) for i := 0; i < sz; i++ { - ctx.arguments.Set(i, v.estack.Pop().Item()) + ctx.arguments.Set(i, v.estack.Pop().Item(), &v.refs) } } @@ -635,11 +640,11 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro case opcode.STSFLD0, opcode.STSFLD1, opcode.STSFLD2, opcode.STSFLD3, opcode.STSFLD4, opcode.STSFLD5, opcode.STSFLD6: item := v.estack.Pop().Item() - ctx.static.Set(int(op-opcode.STSFLD0), item) + ctx.static.Set(int(op-opcode.STSFLD0), item, &v.refs) case opcode.STSFLD: item := v.estack.Pop().Item() - ctx.static.Set(int(parameter[0]), item) + ctx.static.Set(int(parameter[0]), item, &v.refs) case opcode.LDLOC0, opcode.LDLOC1, opcode.LDLOC2, opcode.LDLOC3, opcode.LDLOC4, opcode.LDLOC5, opcode.LDLOC6: item := ctx.local.Get(int(op - opcode.LDLOC0)) @@ -651,11 +656,11 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro case opcode.STLOC0, opcode.STLOC1, opcode.STLOC2, opcode.STLOC3, opcode.STLOC4, opcode.STLOC5, opcode.STLOC6: item := v.estack.Pop().Item() - ctx.local.Set(int(op-opcode.STLOC0), item) + ctx.local.Set(int(op-opcode.STLOC0), item, &v.refs) case opcode.STLOC: item := v.estack.Pop().Item() - ctx.local.Set(int(parameter[0]), item) + ctx.local.Set(int(parameter[0]), item, &v.refs) case opcode.LDARG0, opcode.LDARG1, opcode.LDARG2, opcode.LDARG3, opcode.LDARG4, opcode.LDARG5, opcode.LDARG6: item := ctx.arguments.Get(int(op - opcode.LDARG0)) @@ -667,11 +672,11 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro case opcode.STARG0, opcode.STARG1, opcode.STARG2, opcode.STARG3, opcode.STARG4, opcode.STARG5, opcode.STARG6: item := v.estack.Pop().Item() - ctx.arguments.Set(int(op-opcode.STARG0), item) + ctx.arguments.Set(int(op-opcode.STARG0), item, &v.refs) case opcode.STARG: item := v.estack.Pop().Item() - ctx.arguments.Set(int(parameter[0]), item) + ctx.arguments.Set(int(parameter[0]), item, &v.refs) case opcode.NEWBUFFER: n := toInt(v.estack.Pop().BigInt()) @@ -887,12 +892,12 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro case opcode.INC: x := v.estack.Pop().BigInt() - a := new(big.Int).Add(x, big.NewInt(1)) + a := new(big.Int).Add(x, bigOne) v.estack.PushItem(stackitem.NewBigInteger(a)) case opcode.DEC: x := v.estack.Pop().BigInt() - a := new(big.Int).Sub(x, big.NewInt(1)) + a := new(big.Int).Sub(x, bigOne) v.estack.PushItem(stackitem.NewBigInteger(a)) case opcode.ADD: @@ -1527,14 +1532,14 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro func (v *VM) unloadContext(ctx *Context) { if ctx.local != nil { - ctx.local.Clear() + ctx.local.Clear(&v.refs) } if ctx.arguments != nil { - ctx.arguments.Clear() + ctx.arguments.Clear(&v.refs) } currCtx := v.Context() if ctx.static != nil && currCtx != nil && ctx.static != currCtx.static { - ctx.static.Clear() + ctx.static.Clear(&v.refs) } }