diff --git a/pkg/core/native/native_gas.go b/pkg/core/native/native_gas.go index d02b672e5..9347473f0 100644 --- a/pkg/core/native/native_gas.go +++ b/pkg/core/native/native_gas.go @@ -51,19 +51,19 @@ func newGAS(init int64, p2pSigExtensionsEnabled bool) *GAS { return g } -func (g *GAS) increaseBalance(_ *interop.Context, _ util.Uint160, si *state.StorageItem, amount *big.Int, checkBal *big.Int) error { +func (g *GAS) increaseBalance(_ *interop.Context, _ util.Uint160, si *state.StorageItem, amount *big.Int, checkBal *big.Int) (func(), error) { acc, err := state.NEP17BalanceFromBytes(*si) if err != nil { - return err + return nil, err } if sign := amount.Sign(); sign == 0 { // Requested self-transfer amount can be higher than actual balance. if checkBal != nil && acc.Balance.Cmp(checkBal) < 0 { err = errors.New("insufficient funds") } - return err + return nil, err } else if sign == -1 && acc.Balance.CmpAbs(amount) == -1 { - return errors.New("insufficient funds") + return nil, errors.New("insufficient funds") } acc.Balance.Add(&acc.Balance, amount) if acc.Balance.Sign() != 0 { @@ -71,7 +71,7 @@ func (g *GAS) increaseBalance(_ *interop.Context, _ util.Uint160, si *state.Stor } else { *si = nil } - return nil + return nil, nil } func (g *GAS) balanceFromBytes(si *state.StorageItem) (*big.Int, error) { diff --git a/pkg/core/native/native_neo.go b/pkg/core/native/native_neo.go index cadd03570..7be0b02d5 100644 --- a/pkg/core/native/native_neo.go +++ b/pkg/core/native/native_neo.go @@ -420,28 +420,34 @@ func (n *NEO) getGASPerVote(d *dao.Simple, key []byte, indexes []uint32) []big.I return reward } -func (n *NEO) increaseBalance(ic *interop.Context, h util.Uint160, si *state.StorageItem, amount *big.Int, checkBal *big.Int) error { +func (n *NEO) increaseBalance(ic *interop.Context, h util.Uint160, si *state.StorageItem, amount *big.Int, checkBal *big.Int) (func(), error) { + var postF func() + acc, err := state.NEOBalanceFromBytes(*si) if err != nil { - return err + return nil, err } if (amount.Sign() == -1 && acc.Balance.CmpAbs(amount) == -1) || (amount.Sign() == 0 && checkBal != nil && acc.Balance.Cmp(checkBal) == -1) { - return errors.New("insufficient funds") + return nil, errors.New("insufficient funds") } - if err := n.distributeGas(ic, h, acc); err != nil { - return err + newGas, err := n.distributeGas(ic, acc) + if err != nil { + return nil, err + } + if newGas != nil { // Can be if it was already distributed in the same block. + postF = func() { n.GAS.mint(ic, h, newGas, true) } } if amount.Sign() == 0 { *si = acc.Bytes() - return nil + return postF, nil } if err := n.ModifyAccountVotes(acc, ic.DAO, amount, false); err != nil { - return err + return nil, err } if acc.VoteTo != nil { if err := n.modifyVoterTurnout(ic.DAO, amount); err != nil { - return err + return nil, err } } acc.Balance.Add(&acc.Balance, amount) @@ -450,7 +456,7 @@ func (n *NEO) increaseBalance(ic *interop.Context, h util.Uint160, si *state.Sto } else { *si = nil } - return nil + return postF, nil } func (n *NEO) balanceFromBytes(si *state.StorageItem) (*big.Int, error) { @@ -461,23 +467,17 @@ func (n *NEO) balanceFromBytes(si *state.StorageItem) (*big.Int, error) { return &acc.Balance, err } -func (n *NEO) distributeGas(ic *interop.Context, h util.Uint160, acc *state.NEOBalance) error { +func (n *NEO) distributeGas(ic *interop.Context, acc *state.NEOBalance) (*big.Int, error) { if ic.Block == nil || ic.Block.Index == 0 || ic.Block.Index == acc.BalanceHeight { - return nil + return nil, nil } gen, err := n.calculateBonus(ic.DAO, acc.VoteTo, &acc.Balance, acc.BalanceHeight, ic.Block.Index) if err != nil { - return err + return nil, err } acc.BalanceHeight = ic.Block.Index - // Must store acc before GAS distribution to fix acc's BalanceHeight value in the storage for - // further acc's queries from `onNEP17Payment` if so, see https://github.com/nspcc-dev/neo-go/pull/2181. - key := makeAccountKey(h) - ic.DAO.PutStorageItem(n.ID, key, acc.Bytes()) - - n.GAS.mint(ic, h, gen, true) - return nil + return gen, nil } func (n *NEO) unclaimedGas(ic *interop.Context, args []stackitem.Item) stackitem.Item { @@ -808,7 +808,8 @@ func (n *NEO) VoteInternal(ic *interop.Context, h util.Uint160, pub *keys.Public return err } } - if err := n.distributeGas(ic, h, acc); err != nil { + newGas, err := n.distributeGas(ic, acc) + if err != nil { return err } if err := n.ModifyAccountVotes(acc, ic.DAO, new(big.Int).Neg(&acc.Balance), false); err != nil { @@ -819,6 +820,9 @@ func (n *NEO) VoteInternal(ic *interop.Context, h util.Uint160, pub *keys.Public return err } ic.DAO.PutStorageItem(n.ID, key, acc.Bytes()) + if newGas != nil { // Can be if it was already distributed in the same block. + n.GAS.mint(ic, h, newGas, true) + } return nil } diff --git a/pkg/core/native/native_nep17.go b/pkg/core/native/native_nep17.go index 021364b5d..436bdcd64 100644 --- a/pkg/core/native/native_nep17.go +++ b/pkg/core/native/native_nep17.go @@ -33,7 +33,7 @@ type nep17TokenNative struct { symbol string decimals int64 factor int64 - incBalance func(*interop.Context, util.Uint160, *state.StorageItem, *big.Int, *big.Int) error + incBalance func(*interop.Context, util.Uint160, *state.StorageItem, *big.Int, *big.Int) (func(), error) balFromBytes func(item *state.StorageItem) (*big.Int, error) } @@ -128,7 +128,18 @@ func addrToStackItem(u *util.Uint160) stackitem.Item { } func (c *nep17TokenNative) postTransfer(ic *interop.Context, from, to *util.Uint160, amount *big.Int, - data stackitem.Item, callOnPayment bool) { + data stackitem.Item, callOnPayment bool, postCalls ...func()) { + var skipPostCalls bool + defer func() { + if skipPostCalls { + return + } + for _, f := range postCalls { + if f != nil { + f() + } + } + }() c.emitTransfer(ic, from, to, amount) if to == nil || !callOnPayment { return @@ -148,6 +159,7 @@ func (c *nep17TokenNative) postTransfer(ic *interop.Context, from, to *util.Uint data, } if err := contract.CallFromNative(ic, c.Hash, cs, manifest.MethodOnNEP17Payment, args, false); err != nil { + skipPostCalls = true panic(err) } } @@ -167,37 +179,39 @@ func (c *nep17TokenNative) emitTransfer(ic *interop.Context, from, to *util.Uint // updateAccBalance adds specified amount to the acc's balance. If requiredBalance // is set and amount is 0, then acc's balance is checked against requiredBalance. -func (c *nep17TokenNative) updateAccBalance(ic *interop.Context, acc util.Uint160, amount *big.Int, requiredBalance *big.Int) error { +func (c *nep17TokenNative) updateAccBalance(ic *interop.Context, acc util.Uint160, amount *big.Int, requiredBalance *big.Int) (func(), error) { key := makeAccountKey(acc) si := ic.DAO.GetStorageItem(c.ID, key) if si == nil { if amount.Sign() < 0 { - return errors.New("insufficient funds") + return nil, errors.New("insufficient funds") } if amount.Sign() == 0 { // it's OK to transfer 0 if the balance 0, no need to put si to the storage - return nil + return nil, nil } si = state.StorageItem{} } - err := c.incBalance(ic, acc, &si, amount, requiredBalance) + postF, err := c.incBalance(ic, acc, &si, amount, requiredBalance) if err != nil { if si != nil && amount.Sign() <= 0 { ic.DAO.PutStorageItem(c.ID, key, si) } - return err + return nil, err } if si == nil { ic.DAO.DeleteStorageItem(c.ID, key) } else { ic.DAO.PutStorageItem(c.ID, key, si) } - return nil + return postF, nil } // TransferInternal transfers NEO between accounts. func (c *nep17TokenNative) TransferInternal(ic *interop.Context, from, to util.Uint160, amount *big.Int, data stackitem.Item) error { + var postF1, postF2 func() + if amount.Sign() == -1 { return errors.New("negative amount") } @@ -218,17 +232,20 @@ func (c *nep17TokenNative) TransferInternal(ic *interop.Context, from, to util.U } else { inc = new(big.Int).Neg(inc) } - if err := c.updateAccBalance(ic, from, inc, amount); err != nil { + + postF1, err := c.updateAccBalance(ic, from, inc, amount) + if err != nil { return err } if !isEmpty { - if err := c.updateAccBalance(ic, to, amount, nil); err != nil { + postF2, err = c.updateAccBalance(ic, to, amount, nil) + if err != nil { return err } } - c.postTransfer(ic, &from, &to, amount, data, true) + c.postTransfer(ic, &from, &to, amount, data, true, postF1, postF2) return nil } @@ -254,8 +271,8 @@ func (c *nep17TokenNative) mint(ic *interop.Context, h util.Uint160, amount *big if amount.Sign() == 0 { return } - c.addTokens(ic, h, amount) - c.postTransfer(ic, nil, &h, amount, stackitem.Null{}, callOnPayment) + postF := c.addTokens(ic, h, amount) + c.postTransfer(ic, nil, &h, amount, stackitem.Null{}, callOnPayment, postF) } func (c *nep17TokenNative) burn(ic *interop.Context, h util.Uint160, amount *big.Int) { @@ -263,14 +280,14 @@ func (c *nep17TokenNative) burn(ic *interop.Context, h util.Uint160, amount *big return } amount.Neg(amount) - c.addTokens(ic, h, amount) + postF := c.addTokens(ic, h, amount) amount.Neg(amount) - c.postTransfer(ic, &h, nil, amount, stackitem.Null{}, false) + c.postTransfer(ic, &h, nil, amount, stackitem.Null{}, false, postF) } -func (c *nep17TokenNative) addTokens(ic *interop.Context, h util.Uint160, amount *big.Int) { +func (c *nep17TokenNative) addTokens(ic *interop.Context, h util.Uint160, amount *big.Int) func() { if amount.Sign() == 0 { - return + return nil } key := makeAccountKey(h) @@ -278,7 +295,8 @@ func (c *nep17TokenNative) addTokens(ic *interop.Context, h util.Uint160, amount if si == nil { si = state.StorageItem{} } - if err := c.incBalance(ic, h, &si, amount, nil); err != nil { + postF, err := c.incBalance(ic, h, &si, amount, nil) + if err != nil { panic(err) } if si == nil { @@ -290,6 +308,7 @@ func (c *nep17TokenNative) addTokens(ic *interop.Context, h util.Uint160, amount buf, supply := c.getTotalSupply(ic.DAO) supply.Add(supply, amount) c.saveTotalSupply(ic.DAO, buf, supply) + return postF } func newDescriptor(name string, ret smartcontract.ParamType, ps ...manifest.Parameter) *manifest.Method { diff --git a/pkg/core/native/native_test/neo_test.go b/pkg/core/native/native_test/neo_test.go index 6da70d52e..63a0fe945 100644 --- a/pkg/core/native/native_test/neo_test.go +++ b/pkg/core/native/native_test/neo_test.go @@ -291,7 +291,7 @@ func TestNEO_TransferOnPayment(t *testing.T) { h := neoValidatorsInvoker.Invoke(t, true, "transfer", e.Validator.ScriptHash(), cs.Hash, amount, nil) aer := e.GetTxExecResult(t, h) require.Equal(t, 3, len(aer.Events)) // transfer + GAS claim for sender + onPayment - e.CheckTxNotificationEvent(t, h, 2, state.NotificationEvent{ + e.CheckTxNotificationEvent(t, h, 1, state.NotificationEvent{ ScriptHash: cs.Hash, Name: "LastPayment", Item: stackitem.NewArray([]stackitem.Item{ @@ -305,17 +305,7 @@ func TestNEO_TransferOnPayment(t *testing.T) { h = neoValidatorsInvoker.Invoke(t, true, "transfer", e.Validator.ScriptHash(), cs.Hash, amount, nil) aer = e.GetTxExecResult(t, h) require.Equal(t, 5, len(aer.Events)) // Now we must also have GAS claim for contract and corresponding `onPayment`. - e.CheckTxNotificationEvent(t, h, 2, state.NotificationEvent{ // onPayment for GAS claim - ScriptHash: cs.Hash, - Name: "LastPayment", - Item: stackitem.NewArray([]stackitem.Item{ - stackitem.NewByteArray(e.NativeHash(t, nativenames.Gas).BytesBE()), - stackitem.Null{}, - stackitem.NewBigInteger(big.NewInt(1)), - stackitem.Null{}, - }), - }) - e.CheckTxNotificationEvent(t, h, 4, state.NotificationEvent{ // onPayment for NEO transfer + e.CheckTxNotificationEvent(t, h, 1, state.NotificationEvent{ // onPayment for NEO transfer ScriptHash: cs.Hash, Name: "LastPayment", Item: stackitem.NewArray([]stackitem.Item{ @@ -325,6 +315,16 @@ func TestNEO_TransferOnPayment(t *testing.T) { stackitem.Null{}, }), }) + e.CheckTxNotificationEvent(t, h, 4, state.NotificationEvent{ // onPayment for GAS claim + ScriptHash: cs.Hash, + Name: "LastPayment", + Item: stackitem.NewArray([]stackitem.Item{ + stackitem.NewByteArray(e.NativeHash(t, nativenames.Gas).BytesBE()), + stackitem.Null{}, + stackitem.NewBigInteger(big.NewInt(1)), + stackitem.Null{}, + }), + }) } func TestNEO_Roundtrip(t *testing.T) { @@ -406,7 +406,7 @@ func TestNEO_TransferZeroWithNonZeroBalance(t *testing.T) { aer, err := e.Chain.GetAppExecResults(h, trigger.Application) require.NoError(t, err) require.Equal(t, 2, len(aer[0].Events)) // roundtrip + GAS claim - require.Equal(t, stackitem.NewBigInteger(big.NewInt(0)), aer[0].Events[1].Item.Value().([]stackitem.Item)[2]) // amount is 0 + require.Equal(t, stackitem.NewBigInteger(big.NewInt(0)), aer[0].Events[0].Item.Value().([]stackitem.Item)[2]) // amount is 0 // check balance wasn't changed and height was updated updatedBalance, updatedHeight := e.Chain.GetGoverningTokenBalance(acc.ScriptHash()) require.Equal(t, initialBalance, updatedBalance)