Merge pull request #2287 from nspcc-dev/optimizations-partN

Optimizations part N
This commit is contained in:
Roman Khimov 2021-12-02 11:54:56 +03:00 committed by GitHub
commit d62f5b8c3c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
19 changed files with 158 additions and 137 deletions

1
go.mod
View file

@ -6,6 +6,7 @@ require (
github.com/btcsuite/btcd v0.22.0-beta github.com/btcsuite/btcd v0.22.0-beta
github.com/gorilla/websocket v1.4.2 github.com/gorilla/websocket v1.4.2
github.com/hashicorp/golang-lru v0.5.4 github.com/hashicorp/golang-lru v0.5.4
github.com/holiman/uint256 v1.2.0
github.com/mr-tron/base58 v1.2.0 github.com/mr-tron/base58 v1.2.0
github.com/nspcc-dev/dbft v0.0.0-20210721160347-1b03241391ac github.com/nspcc-dev/dbft v0.0.0-20210721160347-1b03241391ac
github.com/nspcc-dev/go-ordered-json v0.0.0-20210915112629-e1b6cce73d02 github.com/nspcc-dev/go-ordered-json v0.0.0-20210915112629-e1b6cce73d02

2
go.sum
View file

@ -128,6 +128,8 @@ github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/ad
github.com/hashicorp/golang-lru v0.5.4 h1:YDjusn29QI/Das2iO9M0BHnIbxPeyuCHsjMW+lJfyTc= github.com/hashicorp/golang-lru v0.5.4 h1:YDjusn29QI/Das2iO9M0BHnIbxPeyuCHsjMW+lJfyTc=
github.com/hashicorp/golang-lru v0.5.4/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4= github.com/hashicorp/golang-lru v0.5.4/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4=
github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ=
github.com/holiman/uint256 v1.2.0 h1:gpSYcPLWGv4sG43I2mVLiDZCNDh/EpGjSk8tmtxitHM=
github.com/holiman/uint256 v1.2.0/go.mod h1:y4ga/t+u+Xwd7CpDgZESaRcWy0I7XMlTMA25ApIH5Jw=
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8=
github.com/jessevdk/go-flags v0.0.0-20141203071132-1679536dcc89/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= github.com/jessevdk/go-flags v0.0.0-20141203071132-1679536dcc89/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI=

View file

@ -3,11 +3,11 @@ package mempool
import ( import (
"errors" "errors"
"fmt" "fmt"
"math/big"
"math/bits" "math/bits"
"sort" "sort"
"sync" "sync"
"github.com/holiman/uint256"
"github.com/nspcc-dev/neo-go/pkg/core/mempoolevent" "github.com/nspcc-dev/neo-go/pkg/core/mempoolevent"
"github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/core/transaction"
"github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/util"
@ -50,8 +50,8 @@ type items []item
// utilityBalanceAndFees stores sender's balance and overall fees of // utilityBalanceAndFees stores sender's balance and overall fees of
// sender's transactions which are currently in mempool. // sender's transactions which are currently in mempool.
type utilityBalanceAndFees struct { type utilityBalanceAndFees struct {
balance *big.Int balance uint256.Int
feeSum *big.Int feeSum uint256.Int
} }
// Pool stores the unconfirms transactions. // Pool stores the unconfirms transactions.
@ -164,8 +164,7 @@ func (mp *Pool) tryAddSendersFee(tx *transaction.Transaction, feer Feer, needChe
payer := tx.Signers[mp.payerIndex].Account payer := tx.Signers[mp.payerIndex].Account
senderFee, ok := mp.fees[payer] senderFee, ok := mp.fees[payer]
if !ok { if !ok {
senderFee.balance = feer.GetUtilityTokenBalance(payer) _ = senderFee.balance.SetFromBig(feer.GetUtilityTokenBalance(payer))
senderFee.feeSum = big.NewInt(0)
mp.fees[payer] = senderFee mp.fees[payer] = senderFee
} }
if needCheck { if needCheck {
@ -173,23 +172,26 @@ func (mp *Pool) tryAddSendersFee(tx *transaction.Transaction, feer Feer, needChe
if err != nil { if err != nil {
return false return false
} }
senderFee.feeSum.Set(newFeeSum) senderFee.feeSum = newFeeSum
} else { } else {
senderFee.feeSum.Add(senderFee.feeSum, big.NewInt(tx.SystemFee+tx.NetworkFee)) senderFee.feeSum.AddUint64(&senderFee.feeSum, uint64(tx.SystemFee+tx.NetworkFee))
} }
mp.fees[payer] = senderFee
return true return true
} }
// checkBalance returns new cumulative fee balance for account or an error in // checkBalance returns new cumulative fee balance for account or an error in
// case sender doesn't have enough GAS to pay for the transaction. // case sender doesn't have enough GAS to pay for the transaction.
func checkBalance(tx *transaction.Transaction, balance utilityBalanceAndFees) (*big.Int, error) { func checkBalance(tx *transaction.Transaction, balance utilityBalanceAndFees) (uint256.Int, error) {
txFee := big.NewInt(tx.SystemFee + tx.NetworkFee) var txFee uint256.Int
if balance.balance.Cmp(txFee) < 0 {
return nil, ErrInsufficientFunds txFee.SetUint64(uint64(tx.SystemFee + tx.NetworkFee))
if balance.balance.Cmp(&txFee) < 0 {
return txFee, ErrInsufficientFunds
} }
txFee.Add(txFee, balance.feeSum) txFee.Add(&txFee, &balance.feeSum)
if balance.balance.Cmp(txFee) < 0 { if balance.balance.Cmp(&txFee) < 0 {
return nil, ErrConflict return txFee, ErrConflict
} }
return txFee, nil return txFee, nil
} }
@ -323,7 +325,7 @@ func (mp *Pool) removeInternal(hash util.Uint256, feer Feer) {
} }
payer := itm.txn.Signers[mp.payerIndex].Account payer := itm.txn.Signers[mp.payerIndex].Account
senderFee := mp.fees[payer] senderFee := mp.fees[payer]
senderFee.feeSum.Sub(senderFee.feeSum, big.NewInt(tx.SystemFee+tx.NetworkFee)) senderFee.feeSum.SubUint64(&senderFee.feeSum, uint64(tx.SystemFee+tx.NetworkFee))
mp.fees[payer] = senderFee mp.fees[payer] = senderFee
if feer.P2PSigExtensionsEnabled() { if feer.P2PSigExtensionsEnabled() {
// remove all conflicting hashes from mp.conflicts list // remove all conflicting hashes from mp.conflicts list
@ -420,7 +422,7 @@ func (mp *Pool) checkPolicy(tx *transaction.Transaction, policyChanged bool) boo
// New returns a new Pool struct. // New returns a new Pool struct.
func New(capacity int, payerIndex int, enableSubscriptions bool) *Pool { func New(capacity int, payerIndex int, enableSubscriptions bool) *Pool {
mp := &Pool{ mp := &Pool{
verifiedMap: make(map[util.Uint256]*transaction.Transaction), verifiedMap: make(map[util.Uint256]*transaction.Transaction, capacity),
verifiedTxes: make([]item, 0, capacity), verifiedTxes: make([]item, 0, capacity),
capacity: capacity, capacity: capacity,
payerIndex: payerIndex, payerIndex: payerIndex,
@ -507,8 +509,7 @@ func (mp *Pool) checkTxConflicts(tx *transaction.Transaction, fee Feer) ([]*tran
payer := tx.Signers[mp.payerIndex].Account payer := tx.Signers[mp.payerIndex].Account
actualSenderFee, ok := mp.fees[payer] actualSenderFee, ok := mp.fees[payer]
if !ok { if !ok {
actualSenderFee.balance = fee.GetUtilityTokenBalance(payer) actualSenderFee.balance.SetFromBig(fee.GetUtilityTokenBalance(payer))
actualSenderFee.feeSum = big.NewInt(0)
} }
var expectedSenderFee utilityBalanceAndFees var expectedSenderFee utilityBalanceAndFees
@ -541,13 +542,10 @@ func (mp *Pool) checkTxConflicts(tx *transaction.Transaction, fee Feer) ([]*tran
conflictsToBeRemoved = append(conflictsToBeRemoved, existingTx) conflictsToBeRemoved = append(conflictsToBeRemoved, existingTx)
} }
// Step 3: take into account sender's conflicting transactions before balance check. // Step 3: take into account sender's conflicting transactions before balance check.
expectedSenderFee = utilityBalanceAndFees{ expectedSenderFee = actualSenderFee
balance: new(big.Int).Set(actualSenderFee.balance),
feeSum: new(big.Int).Set(actualSenderFee.feeSum),
}
for _, conflictingTx := range conflictsToBeRemoved { for _, conflictingTx := range conflictsToBeRemoved {
if conflictingTx.Signers[mp.payerIndex].Account.Equals(payer) { if conflictingTx.Signers[mp.payerIndex].Account.Equals(payer) {
expectedSenderFee.feeSum.Sub(expectedSenderFee.feeSum, big.NewInt(conflictingTx.SystemFee+conflictingTx.NetworkFee)) expectedSenderFee.feeSum.SubUint64(&expectedSenderFee.feeSum, uint64(conflictingTx.SystemFee+conflictingTx.NetworkFee))
} }
} }
} else { } else {

View file

@ -7,6 +7,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/holiman/uint256"
"github.com/nspcc-dev/neo-go/internal/random" "github.com/nspcc-dev/neo-go/internal/random"
"github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/core/transaction"
"github.com/nspcc-dev/neo-go/pkg/network/payload" "github.com/nspcc-dev/neo-go/pkg/network/payload"
@ -268,8 +269,8 @@ func TestMemPoolFees(t *testing.T) {
require.NoError(t, mp.Add(tx1, fs)) require.NoError(t, mp.Add(tx1, fs))
require.Equal(t, 1, len(mp.fees)) require.Equal(t, 1, len(mp.fees))
require.Equal(t, utilityBalanceAndFees{ require.Equal(t, utilityBalanceAndFees{
balance: big.NewInt(fs.balance), balance: *uint256.NewInt(uint64(fs.balance)),
feeSum: big.NewInt(tx1.NetworkFee), feeSum: *uint256.NewInt(uint64(tx1.NetworkFee)),
}, mp.fees[sender0]) }, mp.fees[sender0])
// balance shouldn't change after adding one more transaction // balance shouldn't change after adding one more transaction
@ -280,8 +281,8 @@ func TestMemPoolFees(t *testing.T) {
require.Equal(t, 2, len(mp.verifiedTxes)) require.Equal(t, 2, len(mp.verifiedTxes))
require.Equal(t, 1, len(mp.fees)) require.Equal(t, 1, len(mp.fees))
require.Equal(t, utilityBalanceAndFees{ require.Equal(t, utilityBalanceAndFees{
balance: big.NewInt(fs.balance), balance: *uint256.NewInt(uint64(fs.balance)),
feeSum: big.NewInt(fs.balance), feeSum: *uint256.NewInt(uint64(fs.balance)),
}, mp.fees[sender0]) }, mp.fees[sender0])
// can't add more transactions as we don't have enough GAS // can't add more transactions as we don't have enough GAS
@ -292,8 +293,8 @@ func TestMemPoolFees(t *testing.T) {
require.Error(t, mp.Add(tx3, fs)) require.Error(t, mp.Add(tx3, fs))
require.Equal(t, 1, len(mp.fees)) require.Equal(t, 1, len(mp.fees))
require.Equal(t, utilityBalanceAndFees{ require.Equal(t, utilityBalanceAndFees{
balance: big.NewInt(fs.balance), balance: *uint256.NewInt(uint64(fs.balance)),
feeSum: big.NewInt(fs.balance), feeSum: *uint256.NewInt(uint64(fs.balance)),
}, mp.fees[sender0]) }, mp.fees[sender0])
// check whether sender's fee updates correctly // check whether sender's fee updates correctly
@ -302,8 +303,8 @@ func TestMemPoolFees(t *testing.T) {
}, fs) }, fs)
require.Equal(t, 1, len(mp.fees)) require.Equal(t, 1, len(mp.fees))
require.Equal(t, utilityBalanceAndFees{ require.Equal(t, utilityBalanceAndFees{
balance: big.NewInt(fs.balance), balance: *uint256.NewInt(uint64(fs.balance)),
feeSum: big.NewInt(tx2.NetworkFee), feeSum: *uint256.NewInt(uint64(tx2.NetworkFee)),
}, mp.fees[sender0]) }, mp.fees[sender0])
// there should be nothing left // there should be nothing left

View file

@ -90,6 +90,12 @@ const (
var ( var (
// prefixCommittee is a key used to store committee. // prefixCommittee is a key used to store committee.
prefixCommittee = []byte{14} prefixCommittee = []byte{14}
bigCommitteeRewardRatio = big.NewInt(committeeRewardRatio)
bigVoterRewardRatio = big.NewInt(voterRewardRatio)
bigVoterRewardFactor = big.NewInt(voterRewardFactor)
bigEffectiveVoterTurnout = big.NewInt(effectiveVoterTurnout)
big100 = big.NewInt(100)
) )
// makeValidatorKey creates a key from account script hash. // makeValidatorKey creates a key from account script hash.
@ -316,24 +322,26 @@ func (n *NEO) PostPersist(ic *interop.Context) error {
pubs := n.GetCommitteeMembers() pubs := n.GetCommitteeMembers()
committeeSize := len(ic.Chain.GetConfig().StandbyCommittee) committeeSize := len(ic.Chain.GetConfig().StandbyCommittee)
index := int(ic.Block.Index) % committeeSize index := int(ic.Block.Index) % committeeSize
committeeReward := new(big.Int).Mul(gas, big.NewInt(committeeRewardRatio)) committeeReward := new(big.Int).Mul(gas, bigCommitteeRewardRatio)
n.GAS.mint(ic, pubs[index].GetScriptHash(), committeeReward.Div(committeeReward, big.NewInt(100)), false) n.GAS.mint(ic, pubs[index].GetScriptHash(), committeeReward.Div(committeeReward, big100), false)
if ShouldUpdateCommittee(ic.Block.Index, ic.Chain) { if ShouldUpdateCommittee(ic.Block.Index, ic.Chain) {
var voterReward = big.NewInt(voterRewardRatio) var voterReward = new(big.Int).Set(bigVoterRewardRatio)
voterReward.Mul(voterReward, gas) voterReward.Mul(voterReward, gas)
voterReward.Mul(voterReward, big.NewInt(voterRewardFactor*int64(committeeSize))) voterReward.Mul(voterReward, big.NewInt(voterRewardFactor*int64(committeeSize)))
var validatorsCount = ic.Chain.GetConfig().ValidatorsCount var validatorsCount = ic.Chain.GetConfig().ValidatorsCount
voterReward.Div(voterReward, big.NewInt(int64(committeeSize+validatorsCount))) voterReward.Div(voterReward, big.NewInt(int64(committeeSize+validatorsCount)))
voterReward.Div(voterReward, big.NewInt(100)) voterReward.Div(voterReward, big100)
var cs = n.committee.Load().(keysWithVotes) var cs = n.committee.Load().(keysWithVotes)
var key = make([]byte, 38) var key = make([]byte, 38)
for i := range cs { for i := range cs {
if cs[i].Votes.Sign() > 0 { if cs[i].Votes.Sign() > 0 {
tmp := big.NewInt(1) var tmp = new(big.Int)
if i < validatorsCount { if i < validatorsCount {
tmp = big.NewInt(2) tmp.Set(intTwo)
} else {
tmp.Set(intOne)
} }
tmp.Mul(tmp, voterReward) tmp.Mul(tmp, voterReward)
tmp.Div(tmp, cs[i].Votes) tmp.Div(tmp, cs[i].Votes)
@ -633,7 +641,7 @@ func (n *NEO) calculateBonus(d dao.DAO, vote *keys.PublicKey, value *big.Int, st
var reward = n.getGASPerVote(d, key, start, end) var reward = n.getGASPerVote(d, key, start, end)
var tmp = new(big.Int).Sub(&reward[1], &reward[0]) var tmp = new(big.Int).Sub(&reward[1], &reward[0])
tmp.Mul(tmp, value) tmp.Mul(tmp, value)
tmp.Div(tmp, big.NewInt(voterRewardFactor)) tmp.Div(tmp, bigVoterRewardFactor)
tmp.Add(tmp, r) tmp.Add(tmp, r)
return tmp, nil return tmp, nil
} }
@ -982,7 +990,7 @@ func (n *NEO) computeCommitteeMembers(bc blockchainer.Blockchainer, d dao.DAO) (
} }
votersCount := bigint.FromBytes(si) votersCount := bigint.FromBytes(si)
// votersCount / totalSupply must be >= 0.2 // votersCount / totalSupply must be >= 0.2
votersCount.Mul(votersCount, big.NewInt(effectiveVoterTurnout)) votersCount.Mul(votersCount, bigEffectiveVoterTurnout)
_, totalSupply := n.getTotalSupply(d) _, totalSupply := n.getTotalSupply(d)
voterTurnout := votersCount.Div(votersCount, totalSupply) voterTurnout := votersCount.Div(votersCount, totalSupply)

View file

@ -13,6 +13,7 @@ import (
) )
var intOne = big.NewInt(1) var intOne = big.NewInt(1)
var intTwo = big.NewInt(2)
func getConvertibleFromDAO(id int32, d dao.DAO, key []byte, conv stackitem.Convertible) error { func getConvertibleFromDAO(id int32, d dao.DAO, key []byte, conv stackitem.Convertible) error {
si := d.GetStorageItem(id, key) si := d.GetStorageItem(id, key)

View file

@ -252,8 +252,8 @@ func (s *MemCachedStore) persist(isSync bool) (int, error) {
// unprotected while writes are handled by s proper. // unprotected while writes are handled by s proper.
var tempstore = &MemCachedStore{MemoryStore: MemoryStore{mem: s.mem, del: s.del}, ps: s.ps} var tempstore = &MemCachedStore{MemoryStore: MemoryStore{mem: s.mem, del: s.del}, ps: s.ps}
s.ps = tempstore s.ps = tempstore
s.mem = make(map[string][]byte) s.mem = make(map[string][]byte, len(s.mem))
s.del = make(map[string]bool) s.del = make(map[string]bool, len(s.del))
if !isSync { if !isSync {
s.mut.Unlock() s.mut.Unlock()
} }

View file

@ -29,6 +29,9 @@ const SignatureLen = 64
// PublicKeys is a list of public keys. // PublicKeys is a list of public keys.
type PublicKeys []*PublicKey type PublicKeys []*PublicKey
var big0 = big.NewInt(0)
var big3 = big.NewInt(3)
func (keys PublicKeys) Len() int { return len(keys) } func (keys PublicKeys) Len() int { return len(keys) }
func (keys PublicKeys) Swap(i, j int) { keys[i], keys[j] = keys[j], keys[i] } func (keys PublicKeys) Swap(i, j int) { keys[i], keys[j] = keys[j], keys[i] }
func (keys PublicKeys) Less(i, j int) bool { func (keys PublicKeys) Less(i, j int) bool {
@ -189,12 +192,12 @@ func decodeCompressedY(x *big.Int, ylsb uint, curve elliptic.Curve) (*big.Int, e
var a *big.Int var a *big.Int
switch curve.(type) { switch curve.(type) {
case *btcec.KoblitzCurve: case *btcec.KoblitzCurve:
a = big.NewInt(0) a = big0
default: default:
a = big.NewInt(3) a = big3
} }
cp := curve.Params() cp := curve.Params()
xCubed := new(big.Int).Exp(x, big.NewInt(3), cp.P) xCubed := new(big.Int).Exp(x, big3, cp.P)
aX := new(big.Int).Mul(x, a) aX := new(big.Int).Mul(x, a)
aX.Mod(aX, cp.P) aX.Mod(aX, cp.P)
ySquared := new(big.Int).Sub(xCubed, aX) ySquared := new(big.Int).Sub(xCubed, aX)

View file

@ -15,6 +15,8 @@ const (
wordSizeBytes = bits.UintSize / 8 wordSizeBytes = bits.UintSize / 8
) )
var bigOne = big.NewInt(1)
// FromBytesUnsigned converts data in little-endian format to an unsigned integer. // FromBytesUnsigned converts data in little-endian format to an unsigned integer.
func FromBytesUnsigned(data []byte) *big.Int { func FromBytesUnsigned(data []byte) *big.Int {
bs := slice.CopyReverse(data) bs := slice.CopyReverse(data)
@ -70,7 +72,7 @@ func FromBytes(data []byte) *big.Int {
n.SetBits(ws) n.SetBits(ws)
n.Neg(n) n.Neg(n)
return n.Sub(n, big.NewInt(1)) return n.Sub(n, bigOne)
} }
return n.SetBits(ws) return n.SetBits(ws)
@ -114,7 +116,7 @@ func ToPreallocatedBytes(n *big.Int, data []byte) []byte {
if sign == 1 { if sign == 1 {
ws = n.Bits() ws = n.Bits()
} else { } else {
n1 := new(big.Int).Add(n, big.NewInt(1)) n1 := new(big.Int).Add(n, bigOne)
if n1.Sign() == 0 { // n == -1 if n1.Sign() == 0 { // n == -1
return append(data, 0xFF) return append(data, 0xFF)
} }

View file

@ -143,7 +143,11 @@ func (w *BinWriter) WriteVarBytes(b []byte) {
// WriteString writes a variable length string into the underlying io.Writer. // WriteString writes a variable length string into the underlying io.Writer.
func (w *BinWriter) WriteString(s string) { func (w *BinWriter) WriteString(s string) {
w.WriteVarBytes([]byte(s)) w.WriteVarUint(uint64(len(s)))
if w.Err != nil {
return
}
_, w.Err = io.WriteString(w.w, s)
} }
// Grow tries to increase underlying buffer capacity so that at least n bytes // Grow tries to increase underlying buffer capacity so that at least n bytes

View file

@ -32,9 +32,9 @@ type Context struct {
// Evaluation stack pointer. // Evaluation stack pointer.
estack *Stack estack *Stack
static *Slot static *slot
local *Slot local slot
arguments *Slot arguments slot
// Exception context stack. // Exception context stack.
tryStack Stack tryStack Stack
@ -277,16 +277,19 @@ func (c *Context) DumpStaticSlot() string {
// DumpLocalSlot returns json formatted representation of the given slot. // DumpLocalSlot returns json formatted representation of the given slot.
func (c *Context) DumpLocalSlot() string { func (c *Context) DumpLocalSlot() string {
return dumpSlot(c.local) return dumpSlot(&c.local)
} }
// DumpArgumentsSlot returns json formatted representation of the given slot. // DumpArgumentsSlot returns json formatted representation of the given slot.
func (c *Context) DumpArgumentsSlot() string { func (c *Context) DumpArgumentsSlot() string {
return dumpSlot(c.arguments) return dumpSlot(&c.arguments)
} }
// dumpSlot returns json formatted representation of the given slot. // dumpSlot returns json formatted representation of the given slot.
func dumpSlot(s *Slot) string { func dumpSlot(s *slot) string {
if s == nil || *s == nil {
return "[]"
}
b, _ := json.MarshalIndent(s, "", " ") b, _ := json.MarshalIndent(s, "", " ")
return string(b) return string(b)
} }

View file

@ -239,8 +239,8 @@ func compareStacks(t *testing.T, expected []vmUTStackItem, actual *Stack) {
compareItemArrays(t, expected, actual.Len(), func(i int) stackitem.Item { return actual.Peek(i).Item() }) compareItemArrays(t, expected, actual.Len(), func(i int) stackitem.Item { return actual.Peek(i).Item() })
} }
func compareSlots(t *testing.T, expected []vmUTStackItem, actual *Slot) { func compareSlots(t *testing.T, expected []vmUTStackItem, actual *slot) {
if actual.storage == nil && len(expected) == 0 { if (actual == nil || *actual == nil) && len(expected) == 0 {
return return
} }
require.NotNil(t, actual) require.NotNil(t, actual)

View file

@ -59,8 +59,8 @@ func opParamSlotsPushVM(op opcode.Opcode, param []byte, sslot int, slotloc int,
v.Context().static.init(sslot) v.Context().static.init(sslot)
} }
if slotloc != 0 && slotarg != 0 { if slotloc != 0 && slotarg != 0 {
v.Context().local = v.newSlot(slotloc) v.Context().local.init(slotloc)
v.Context().arguments = v.newSlot(slotarg) v.Context().arguments.init(slotarg)
} }
for i := range items { for i := range items {
item, ok := items[i].(stackitem.Item) item, ok := items[i].(stackitem.Item)

View file

@ -6,75 +6,58 @@ import (
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
) )
// Slot is a fixed-size slice of stack items. // slot is a fixed-size slice of stack items.
type Slot struct { type slot []stackitem.Item
storage []stackitem.Item
refs *refCounter
}
// newSlot returns new slot with the provided reference counter.
func newSlot(refs *refCounter) *Slot {
return &Slot{
refs: refs,
}
}
// init sets static slot size to n. It is intended to be used only by INITSSLOT. // init sets static slot size to n. It is intended to be used only by INITSSLOT.
func (s *Slot) init(n int) { func (s *slot) init(n int) {
if s.storage != nil { if *s != nil {
panic("already initialized") panic("already initialized")
} }
s.storage = make([]stackitem.Item, n) *s = make([]stackitem.Item, n)
}
func (v *VM) newSlot(n int) *Slot {
s := newSlot(&v.refs)
s.init(n)
return s
} }
// Set sets i-th storage slot. // Set sets i-th storage slot.
func (s *Slot) Set(i int, item stackitem.Item) { func (s slot) Set(i int, item stackitem.Item, refs *refCounter) {
if s.storage[i] == item { if s[i] == item {
return return
} }
old := s.storage[i] old := s[i]
s.storage[i] = item s[i] = item
if old != nil { if old != nil {
s.refs.Remove(old) refs.Remove(old)
} }
s.refs.Add(item) refs.Add(item)
} }
// Get returns item contained in i-th slot. // Get returns item contained in i-th slot.
func (s *Slot) Get(i int) stackitem.Item { func (s slot) Get(i int) stackitem.Item {
if item := s.storage[i]; item != nil { if item := s[i]; item != nil {
return item return item
} }
return stackitem.Null{} return stackitem.Null{}
} }
// Clear removes all slot variables from reference counter. // Clear removes all slot variables from reference counter.
func (s *Slot) Clear() { func (s slot) Clear(refs *refCounter) {
for _, item := range s.storage { for _, item := range s {
s.refs.Remove(item) refs.Remove(item)
} }
} }
// Size returns slot size. // Size returns slot size.
func (s *Slot) Size() int { func (s slot) Size() int {
if s.storage == nil { if s == nil {
panic("not initialized") panic("not initialized")
} }
return len(s.storage) return len(s)
} }
// MarshalJSON implements JSON marshalling interface. // MarshalJSON implements JSON marshalling interface.
func (s *Slot) MarshalJSON() ([]byte, error) { func (s slot) MarshalJSON() ([]byte, error) {
items := s.storage arr := make([]json.RawMessage, len(s))
arr := make([]json.RawMessage, len(items)) for i := range s {
for i := range items { data, err := stackitem.ToJSONWithTypes(s[i])
data, err := stackitem.ToJSONWithTypes(items[i])
if err == nil { if err == nil {
arr[i] = data arr[i] = data
} }

View file

@ -9,8 +9,8 @@ import (
) )
func TestSlot_Get(t *testing.T) { func TestSlot_Get(t *testing.T) {
s := newSlot(newRefCounter()) rc := newRefCounter()
require.NotNil(t, s) var s slot
require.Panics(t, func() { s.Size() }) require.Panics(t, func() { s.Size() })
s.init(3) s.init(3)
@ -20,6 +20,6 @@ func TestSlot_Get(t *testing.T) {
item := s.Get(2) item := s.Get(2)
require.Equal(t, stackitem.Null{}, item) require.Equal(t, stackitem.Null{}, item)
s.Set(1, stackitem.NewBigInteger(big.NewInt(42))) s.Set(1, stackitem.NewBigInteger(big.NewInt(42)), rc)
require.Equal(t, stackitem.NewBigInteger(big.NewInt(42)), s.Get(1)) require.Equal(t, stackitem.NewBigInteger(big.NewInt(42)), s.Get(1))
} }

View file

@ -1097,7 +1097,7 @@ func (i *Buffer) Len() int {
// Values of Interop items are not deeply copied. // Values of Interop items are not deeply copied.
// It does preserve duplicates only for non-primitive types. // It does preserve duplicates only for non-primitive types.
func DeepCopy(item Item) Item { func DeepCopy(item Item) Item {
seen := make(map[Item]Item) seen := make(map[Item]Item, typicalNumOfItems)
return deepCopy(item, seen) return deepCopy(item, seen)
} }

View file

@ -42,7 +42,7 @@ var ErrTooDeep = errors.New("too deep")
// Array, Struct -> array // Array, Struct -> array
// Map -> map with keys as UTF-8 bytes // Map -> map with keys as UTF-8 bytes
func ToJSON(item Item) ([]byte, error) { func ToJSON(item Item) ([]byte, error) {
seen := make(map[Item]sliceNoPointer) seen := make(map[Item]sliceNoPointer, typicalNumOfItems)
return toJSON(nil, seen, item) return toJSON(nil, seen, item)
} }
@ -260,7 +260,7 @@ func (d *decoder) decodeMap() (*Map, error) {
// ToJSONWithTypes serializes any stackitem to JSON in a lossless way. // ToJSONWithTypes serializes any stackitem to JSON in a lossless way.
func ToJSONWithTypes(item Item) ([]byte, error) { func ToJSONWithTypes(item Item) ([]byte, error) {
result, err := toJSONWithTypes(item, make(map[Item]bool)) result, err := toJSONWithTypes(item, make(map[Item]bool, typicalNumOfItems))
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -13,6 +13,12 @@ import (
// (including itself). // (including itself).
const MaxDeserialized = 2048 const MaxDeserialized = 2048
// typicalNumOfItems is the number of items covering most serializaton needs.
// It's a hint used for map creation, so it's not limiting anything, it's just
// a microoptimization to avoid excessive reallocations. Most of the serialized
// items are structs, so there is at least one of them.
const typicalNumOfItems = 4
// ErrRecursive is returned on attempts to serialize some recursive stack item // ErrRecursive is returned on attempts to serialize some recursive stack item
// (like array including an item with reference to the same array). // (like array including an item with reference to the same array).
var ErrRecursive = errors.New("recursive item") var ErrRecursive = errors.New("recursive item")
@ -40,7 +46,7 @@ type deserContext struct {
func Serialize(item Item) ([]byte, error) { func Serialize(item Item) ([]byte, error) {
sc := serContext{ sc := serContext{
allowInvalid: false, allowInvalid: false,
seen: make(map[Item]sliceNoPointer), seen: make(map[Item]sliceNoPointer, typicalNumOfItems),
} }
err := sc.serialize(item) err := sc.serialize(item)
if err != nil { if err != nil {
@ -69,7 +75,7 @@ func EncodeBinary(item Item, w *io.BinWriter) {
func EncodeBinaryProtected(item Item, w *io.BinWriter) { func EncodeBinaryProtected(item Item, w *io.BinWriter) {
sc := serContext{ sc := serContext{
allowInvalid: true, allowInvalid: true,
seen: make(map[Item]sliceNoPointer), seen: make(map[Item]sliceNoPointer, typicalNumOfItems),
} }
err := sc.serialize(item) err := sc.serialize(item)
if err != nil { if err != nil {
@ -79,6 +85,18 @@ func EncodeBinaryProtected(item Item, w *io.BinWriter) {
w.WriteBytes(sc.data) w.WriteBytes(sc.data)
} }
func (w *serContext) writeArray(item Item, arr []Item, start int) error {
w.seen[item] = sliceNoPointer{}
w.appendVarUint(uint64(len(arr)))
for i := range arr {
if err := w.serialize(arr[i]); err != nil {
return err
}
}
w.seen[item] = sliceNoPointer{start, len(w.data)}
return nil
}
func (w *serContext) serialize(item Item) error { func (w *serContext) serialize(item Item) error {
if v, ok := w.seen[item]; ok { if v, ok := w.seen[item]; ok {
if v.start == v.end { if v.start == v.end {
@ -121,28 +139,20 @@ func (w *serContext) serialize(item Item) error {
} else { } else {
return fmt.Errorf("%w: Interop", ErrUnserializable) return fmt.Errorf("%w: Interop", ErrUnserializable)
} }
case *Array, *Struct: case *Array:
w.seen[item] = sliceNoPointer{} w.data = append(w.data, byte(ArrayT))
if err := w.writeArray(item, t.value, start); err != nil {
_, isArray := t.(*Array) return err
if isArray {
w.data = append(w.data, byte(ArrayT))
} else {
w.data = append(w.data, byte(StructT))
} }
case *Struct:
arr := t.Value().([]Item) w.data = append(w.data, byte(StructT))
w.appendVarUint(uint64(len(arr))) if err := w.writeArray(item, t.value, start); err != nil {
for i := range arr { return err
if err := w.serialize(arr[i]); err != nil {
return err
}
} }
w.seen[item] = sliceNoPointer{start, len(w.data)}
case *Map: case *Map:
w.seen[item] = sliceNoPointer{} w.seen[item] = sliceNoPointer{}
elems := t.Value().([]MapElement) elems := t.value
w.data = append(w.data, byte(MapT)) w.data = append(w.data, byte(MapT))
w.appendVarUint(uint64(len(elems))) w.appendVarUint(uint64(len(elems)))
for i := range elems { for i := range elems {

View file

@ -90,6 +90,8 @@ type VM struct {
invTree *InvocationTree invTree *InvocationTree
} }
var bigOne = big.NewInt(1)
// New returns a new VM object ready to load AVM bytecode scripts. // New returns a new VM object ready to load AVM bytecode scripts.
func New() *VM { func New() *VM {
return NewWithTrigger(trigger.Application) return NewWithTrigger(trigger.Application)
@ -105,6 +107,7 @@ func NewWithTrigger(t trigger.Type) *VM {
} }
initStack(&vm.istack, "invocation", nil) initStack(&vm.istack, "invocation", nil)
vm.istack.elems = make([]Element, 0, 8) // Most of invocations use one-two contracts, but they're likely to have internal calls.
vm.estack = newStack("evaluation", &vm.refs) vm.estack = newStack("evaluation", &vm.refs)
return vm return vm
} }
@ -310,13 +313,15 @@ func (v *VM) LoadNEFMethod(exe *nef.File, caller util.Uint160, hash util.Uint160
// It should be used for calling from native contracts. // It should be used for calling from native contracts.
func (v *VM) loadScriptWithCallingHash(b []byte, exe *nef.File, caller util.Uint160, func (v *VM) loadScriptWithCallingHash(b []byte, exe *nef.File, caller util.Uint160,
hash util.Uint160, f callflag.CallFlag, rvcount int, offset int) { hash util.Uint160, f callflag.CallFlag, rvcount int, offset int) {
var sl slot
v.checkInvocationStackSize() v.checkInvocationStackSize()
ctx := NewContextWithParams(b, rvcount, offset) ctx := NewContextWithParams(b, rvcount, offset)
v.estack = newStack("evaluation", &v.refs) v.estack = newStack("evaluation", &v.refs)
ctx.estack = v.estack ctx.estack = v.estack
initStack(&ctx.tryStack, "exception", nil) initStack(&ctx.tryStack, "exception", nil)
ctx.callFlag = f ctx.callFlag = f
ctx.static = newSlot(&v.refs) ctx.static = &sl
ctx.scriptHash = hash ctx.scriptHash = hash
ctx.callingScriptHash = caller ctx.callingScriptHash = caller
ctx.NEF = exe ctx.NEF = exe
@ -615,13 +620,13 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro
panic("zero argument") panic("zero argument")
} }
if parameter[0] > 0 { if parameter[0] > 0 {
ctx.local = v.newSlot(int(parameter[0])) ctx.local.init(int(parameter[0]))
} }
if parameter[1] > 0 { if parameter[1] > 0 {
sz := int(parameter[1]) sz := int(parameter[1])
ctx.arguments = v.newSlot(sz) ctx.arguments.init(sz)
for i := 0; i < sz; i++ { for i := 0; i < sz; i++ {
ctx.arguments.Set(i, v.estack.Pop().Item()) ctx.arguments.Set(i, v.estack.Pop().Item(), &v.refs)
} }
} }
@ -635,11 +640,11 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro
case opcode.STSFLD0, opcode.STSFLD1, opcode.STSFLD2, opcode.STSFLD3, opcode.STSFLD4, opcode.STSFLD5, opcode.STSFLD6: case opcode.STSFLD0, opcode.STSFLD1, opcode.STSFLD2, opcode.STSFLD3, opcode.STSFLD4, opcode.STSFLD5, opcode.STSFLD6:
item := v.estack.Pop().Item() item := v.estack.Pop().Item()
ctx.static.Set(int(op-opcode.STSFLD0), item) ctx.static.Set(int(op-opcode.STSFLD0), item, &v.refs)
case opcode.STSFLD: case opcode.STSFLD:
item := v.estack.Pop().Item() item := v.estack.Pop().Item()
ctx.static.Set(int(parameter[0]), item) ctx.static.Set(int(parameter[0]), item, &v.refs)
case opcode.LDLOC0, opcode.LDLOC1, opcode.LDLOC2, opcode.LDLOC3, opcode.LDLOC4, opcode.LDLOC5, opcode.LDLOC6: case opcode.LDLOC0, opcode.LDLOC1, opcode.LDLOC2, opcode.LDLOC3, opcode.LDLOC4, opcode.LDLOC5, opcode.LDLOC6:
item := ctx.local.Get(int(op - opcode.LDLOC0)) item := ctx.local.Get(int(op - opcode.LDLOC0))
@ -651,11 +656,11 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro
case opcode.STLOC0, opcode.STLOC1, opcode.STLOC2, opcode.STLOC3, opcode.STLOC4, opcode.STLOC5, opcode.STLOC6: case opcode.STLOC0, opcode.STLOC1, opcode.STLOC2, opcode.STLOC3, opcode.STLOC4, opcode.STLOC5, opcode.STLOC6:
item := v.estack.Pop().Item() item := v.estack.Pop().Item()
ctx.local.Set(int(op-opcode.STLOC0), item) ctx.local.Set(int(op-opcode.STLOC0), item, &v.refs)
case opcode.STLOC: case opcode.STLOC:
item := v.estack.Pop().Item() item := v.estack.Pop().Item()
ctx.local.Set(int(parameter[0]), item) ctx.local.Set(int(parameter[0]), item, &v.refs)
case opcode.LDARG0, opcode.LDARG1, opcode.LDARG2, opcode.LDARG3, opcode.LDARG4, opcode.LDARG5, opcode.LDARG6: case opcode.LDARG0, opcode.LDARG1, opcode.LDARG2, opcode.LDARG3, opcode.LDARG4, opcode.LDARG5, opcode.LDARG6:
item := ctx.arguments.Get(int(op - opcode.LDARG0)) item := ctx.arguments.Get(int(op - opcode.LDARG0))
@ -667,11 +672,11 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro
case opcode.STARG0, opcode.STARG1, opcode.STARG2, opcode.STARG3, opcode.STARG4, opcode.STARG5, opcode.STARG6: case opcode.STARG0, opcode.STARG1, opcode.STARG2, opcode.STARG3, opcode.STARG4, opcode.STARG5, opcode.STARG6:
item := v.estack.Pop().Item() item := v.estack.Pop().Item()
ctx.arguments.Set(int(op-opcode.STARG0), item) ctx.arguments.Set(int(op-opcode.STARG0), item, &v.refs)
case opcode.STARG: case opcode.STARG:
item := v.estack.Pop().Item() item := v.estack.Pop().Item()
ctx.arguments.Set(int(parameter[0]), item) ctx.arguments.Set(int(parameter[0]), item, &v.refs)
case opcode.NEWBUFFER: case opcode.NEWBUFFER:
n := toInt(v.estack.Pop().BigInt()) n := toInt(v.estack.Pop().BigInt())
@ -887,12 +892,12 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro
case opcode.INC: case opcode.INC:
x := v.estack.Pop().BigInt() x := v.estack.Pop().BigInt()
a := new(big.Int).Add(x, big.NewInt(1)) a := new(big.Int).Add(x, bigOne)
v.estack.PushItem(stackitem.NewBigInteger(a)) v.estack.PushItem(stackitem.NewBigInteger(a))
case opcode.DEC: case opcode.DEC:
x := v.estack.Pop().BigInt() x := v.estack.Pop().BigInt()
a := new(big.Int).Sub(x, big.NewInt(1)) a := new(big.Int).Sub(x, bigOne)
v.estack.PushItem(stackitem.NewBigInteger(a)) v.estack.PushItem(stackitem.NewBigInteger(a))
case opcode.ADD: case opcode.ADD:
@ -1527,14 +1532,14 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro
func (v *VM) unloadContext(ctx *Context) { func (v *VM) unloadContext(ctx *Context) {
if ctx.local != nil { if ctx.local != nil {
ctx.local.Clear() ctx.local.Clear(&v.refs)
} }
if ctx.arguments != nil { if ctx.arguments != nil {
ctx.arguments.Clear() ctx.arguments.Clear(&v.refs)
} }
currCtx := v.Context() currCtx := v.Context()
if ctx.static != nil && currCtx != nil && ctx.static != currCtx.static { if ctx.static != nil && currCtx != nil && ctx.static != currCtx.static {
ctx.static.Clear() ctx.static.Clear(&v.refs)
} }
} }