neo: mint GAS after NEO transfer processing

See neo-project/neo#2734.
This commit is contained in:
Roman Khimov 2022-05-13 14:49:41 +03:00
parent 2f037f1e08
commit 85fe111aea
4 changed files with 79 additions and 56 deletions

View file

@ -51,19 +51,19 @@ func newGAS(init int64, p2pSigExtensionsEnabled bool) *GAS {
return g return g
} }
func (g *GAS) increaseBalance(_ *interop.Context, _ util.Uint160, si *state.StorageItem, amount *big.Int, checkBal *big.Int) error { func (g *GAS) increaseBalance(_ *interop.Context, _ util.Uint160, si *state.StorageItem, amount *big.Int, checkBal *big.Int) (func(), error) {
acc, err := state.NEP17BalanceFromBytes(*si) acc, err := state.NEP17BalanceFromBytes(*si)
if err != nil { if err != nil {
return err return nil, err
} }
if sign := amount.Sign(); sign == 0 { if sign := amount.Sign(); sign == 0 {
// Requested self-transfer amount can be higher than actual balance. // Requested self-transfer amount can be higher than actual balance.
if checkBal != nil && acc.Balance.Cmp(checkBal) < 0 { if checkBal != nil && acc.Balance.Cmp(checkBal) < 0 {
err = errors.New("insufficient funds") err = errors.New("insufficient funds")
} }
return err return nil, err
} else if sign == -1 && acc.Balance.CmpAbs(amount) == -1 { } else if sign == -1 && acc.Balance.CmpAbs(amount) == -1 {
return errors.New("insufficient funds") return nil, errors.New("insufficient funds")
} }
acc.Balance.Add(&acc.Balance, amount) acc.Balance.Add(&acc.Balance, amount)
if acc.Balance.Sign() != 0 { if acc.Balance.Sign() != 0 {
@ -71,7 +71,7 @@ func (g *GAS) increaseBalance(_ *interop.Context, _ util.Uint160, si *state.Stor
} else { } else {
*si = nil *si = nil
} }
return nil return nil, nil
} }
func (g *GAS) balanceFromBytes(si *state.StorageItem) (*big.Int, error) { func (g *GAS) balanceFromBytes(si *state.StorageItem) (*big.Int, error) {

View file

@ -420,28 +420,34 @@ func (n *NEO) getGASPerVote(d *dao.Simple, key []byte, indexes []uint32) []big.I
return reward return reward
} }
func (n *NEO) increaseBalance(ic *interop.Context, h util.Uint160, si *state.StorageItem, amount *big.Int, checkBal *big.Int) error { func (n *NEO) increaseBalance(ic *interop.Context, h util.Uint160, si *state.StorageItem, amount *big.Int, checkBal *big.Int) (func(), error) {
var postF func()
acc, err := state.NEOBalanceFromBytes(*si) acc, err := state.NEOBalanceFromBytes(*si)
if err != nil { if err != nil {
return err return nil, err
} }
if (amount.Sign() == -1 && acc.Balance.CmpAbs(amount) == -1) || if (amount.Sign() == -1 && acc.Balance.CmpAbs(amount) == -1) ||
(amount.Sign() == 0 && checkBal != nil && acc.Balance.Cmp(checkBal) == -1) { (amount.Sign() == 0 && checkBal != nil && acc.Balance.Cmp(checkBal) == -1) {
return errors.New("insufficient funds") return nil, errors.New("insufficient funds")
} }
if err := n.distributeGas(ic, h, acc); err != nil { newGas, err := n.distributeGas(ic, acc)
return err if err != nil {
return nil, err
}
if newGas != nil { // Can be if it was already distributed in the same block.
postF = func() { n.GAS.mint(ic, h, newGas, true) }
} }
if amount.Sign() == 0 { if amount.Sign() == 0 {
*si = acc.Bytes() *si = acc.Bytes()
return nil return postF, nil
} }
if err := n.ModifyAccountVotes(acc, ic.DAO, amount, false); err != nil { if err := n.ModifyAccountVotes(acc, ic.DAO, amount, false); err != nil {
return err return nil, err
} }
if acc.VoteTo != nil { if acc.VoteTo != nil {
if err := n.modifyVoterTurnout(ic.DAO, amount); err != nil { if err := n.modifyVoterTurnout(ic.DAO, amount); err != nil {
return err return nil, err
} }
} }
acc.Balance.Add(&acc.Balance, amount) acc.Balance.Add(&acc.Balance, amount)
@ -450,7 +456,7 @@ func (n *NEO) increaseBalance(ic *interop.Context, h util.Uint160, si *state.Sto
} else { } else {
*si = nil *si = nil
} }
return nil return postF, nil
} }
func (n *NEO) balanceFromBytes(si *state.StorageItem) (*big.Int, error) { func (n *NEO) balanceFromBytes(si *state.StorageItem) (*big.Int, error) {
@ -461,23 +467,17 @@ func (n *NEO) balanceFromBytes(si *state.StorageItem) (*big.Int, error) {
return &acc.Balance, err return &acc.Balance, err
} }
func (n *NEO) distributeGas(ic *interop.Context, h util.Uint160, acc *state.NEOBalance) error { func (n *NEO) distributeGas(ic *interop.Context, acc *state.NEOBalance) (*big.Int, error) {
if ic.Block == nil || ic.Block.Index == 0 || ic.Block.Index == acc.BalanceHeight { if ic.Block == nil || ic.Block.Index == 0 || ic.Block.Index == acc.BalanceHeight {
return nil return nil, nil
} }
gen, err := n.calculateBonus(ic.DAO, acc.VoteTo, &acc.Balance, acc.BalanceHeight, ic.Block.Index) gen, err := n.calculateBonus(ic.DAO, acc.VoteTo, &acc.Balance, acc.BalanceHeight, ic.Block.Index)
if err != nil { if err != nil {
return err return nil, err
} }
acc.BalanceHeight = ic.Block.Index acc.BalanceHeight = ic.Block.Index
// Must store acc before GAS distribution to fix acc's BalanceHeight value in the storage for return gen, nil
// further acc's queries from `onNEP17Payment` if so, see https://github.com/nspcc-dev/neo-go/pull/2181.
key := makeAccountKey(h)
ic.DAO.PutStorageItem(n.ID, key, acc.Bytes())
n.GAS.mint(ic, h, gen, true)
return nil
} }
func (n *NEO) unclaimedGas(ic *interop.Context, args []stackitem.Item) stackitem.Item { func (n *NEO) unclaimedGas(ic *interop.Context, args []stackitem.Item) stackitem.Item {
@ -808,7 +808,8 @@ func (n *NEO) VoteInternal(ic *interop.Context, h util.Uint160, pub *keys.Public
return err return err
} }
} }
if err := n.distributeGas(ic, h, acc); err != nil { newGas, err := n.distributeGas(ic, acc)
if err != nil {
return err return err
} }
if err := n.ModifyAccountVotes(acc, ic.DAO, new(big.Int).Neg(&acc.Balance), false); err != nil { if err := n.ModifyAccountVotes(acc, ic.DAO, new(big.Int).Neg(&acc.Balance), false); err != nil {
@ -819,6 +820,9 @@ func (n *NEO) VoteInternal(ic *interop.Context, h util.Uint160, pub *keys.Public
return err return err
} }
ic.DAO.PutStorageItem(n.ID, key, acc.Bytes()) ic.DAO.PutStorageItem(n.ID, key, acc.Bytes())
if newGas != nil { // Can be if it was already distributed in the same block.
n.GAS.mint(ic, h, newGas, true)
}
return nil return nil
} }

View file

@ -33,7 +33,7 @@ type nep17TokenNative struct {
symbol string symbol string
decimals int64 decimals int64
factor int64 factor int64
incBalance func(*interop.Context, util.Uint160, *state.StorageItem, *big.Int, *big.Int) error incBalance func(*interop.Context, util.Uint160, *state.StorageItem, *big.Int, *big.Int) (func(), error)
balFromBytes func(item *state.StorageItem) (*big.Int, error) balFromBytes func(item *state.StorageItem) (*big.Int, error)
} }
@ -128,7 +128,18 @@ func addrToStackItem(u *util.Uint160) stackitem.Item {
} }
func (c *nep17TokenNative) postTransfer(ic *interop.Context, from, to *util.Uint160, amount *big.Int, func (c *nep17TokenNative) postTransfer(ic *interop.Context, from, to *util.Uint160, amount *big.Int,
data stackitem.Item, callOnPayment bool) { data stackitem.Item, callOnPayment bool, postCalls ...func()) {
var skipPostCalls bool
defer func() {
if skipPostCalls {
return
}
for _, f := range postCalls {
if f != nil {
f()
}
}
}()
c.emitTransfer(ic, from, to, amount) c.emitTransfer(ic, from, to, amount)
if to == nil || !callOnPayment { if to == nil || !callOnPayment {
return return
@ -148,6 +159,7 @@ func (c *nep17TokenNative) postTransfer(ic *interop.Context, from, to *util.Uint
data, data,
} }
if err := contract.CallFromNative(ic, c.Hash, cs, manifest.MethodOnNEP17Payment, args, false); err != nil { if err := contract.CallFromNative(ic, c.Hash, cs, manifest.MethodOnNEP17Payment, args, false); err != nil {
skipPostCalls = true
panic(err) panic(err)
} }
} }
@ -167,37 +179,39 @@ func (c *nep17TokenNative) emitTransfer(ic *interop.Context, from, to *util.Uint
// updateAccBalance adds specified amount to the acc's balance. If requiredBalance // updateAccBalance adds specified amount to the acc's balance. If requiredBalance
// is set and amount is 0, then acc's balance is checked against requiredBalance. // is set and amount is 0, then acc's balance is checked against requiredBalance.
func (c *nep17TokenNative) updateAccBalance(ic *interop.Context, acc util.Uint160, amount *big.Int, requiredBalance *big.Int) error { func (c *nep17TokenNative) updateAccBalance(ic *interop.Context, acc util.Uint160, amount *big.Int, requiredBalance *big.Int) (func(), error) {
key := makeAccountKey(acc) key := makeAccountKey(acc)
si := ic.DAO.GetStorageItem(c.ID, key) si := ic.DAO.GetStorageItem(c.ID, key)
if si == nil { if si == nil {
if amount.Sign() < 0 { if amount.Sign() < 0 {
return errors.New("insufficient funds") return nil, errors.New("insufficient funds")
} }
if amount.Sign() == 0 { if amount.Sign() == 0 {
// it's OK to transfer 0 if the balance 0, no need to put si to the storage // it's OK to transfer 0 if the balance 0, no need to put si to the storage
return nil return nil, nil
} }
si = state.StorageItem{} si = state.StorageItem{}
} }
err := c.incBalance(ic, acc, &si, amount, requiredBalance) postF, err := c.incBalance(ic, acc, &si, amount, requiredBalance)
if err != nil { if err != nil {
if si != nil && amount.Sign() <= 0 { if si != nil && amount.Sign() <= 0 {
ic.DAO.PutStorageItem(c.ID, key, si) ic.DAO.PutStorageItem(c.ID, key, si)
} }
return err return nil, err
} }
if si == nil { if si == nil {
ic.DAO.DeleteStorageItem(c.ID, key) ic.DAO.DeleteStorageItem(c.ID, key)
} else { } else {
ic.DAO.PutStorageItem(c.ID, key, si) ic.DAO.PutStorageItem(c.ID, key, si)
} }
return nil return postF, nil
} }
// TransferInternal transfers NEO between accounts. // TransferInternal transfers NEO between accounts.
func (c *nep17TokenNative) TransferInternal(ic *interop.Context, from, to util.Uint160, amount *big.Int, data stackitem.Item) error { func (c *nep17TokenNative) TransferInternal(ic *interop.Context, from, to util.Uint160, amount *big.Int, data stackitem.Item) error {
var postF1, postF2 func()
if amount.Sign() == -1 { if amount.Sign() == -1 {
return errors.New("negative amount") return errors.New("negative amount")
} }
@ -218,17 +232,20 @@ func (c *nep17TokenNative) TransferInternal(ic *interop.Context, from, to util.U
} else { } else {
inc = new(big.Int).Neg(inc) inc = new(big.Int).Neg(inc)
} }
if err := c.updateAccBalance(ic, from, inc, amount); err != nil {
postF1, err := c.updateAccBalance(ic, from, inc, amount)
if err != nil {
return err return err
} }
if !isEmpty { if !isEmpty {
if err := c.updateAccBalance(ic, to, amount, nil); err != nil { postF2, err = c.updateAccBalance(ic, to, amount, nil)
if err != nil {
return err return err
} }
} }
c.postTransfer(ic, &from, &to, amount, data, true) c.postTransfer(ic, &from, &to, amount, data, true, postF1, postF2)
return nil return nil
} }
@ -254,8 +271,8 @@ func (c *nep17TokenNative) mint(ic *interop.Context, h util.Uint160, amount *big
if amount.Sign() == 0 { if amount.Sign() == 0 {
return return
} }
c.addTokens(ic, h, amount) postF := c.addTokens(ic, h, amount)
c.postTransfer(ic, nil, &h, amount, stackitem.Null{}, callOnPayment) c.postTransfer(ic, nil, &h, amount, stackitem.Null{}, callOnPayment, postF)
} }
func (c *nep17TokenNative) burn(ic *interop.Context, h util.Uint160, amount *big.Int) { func (c *nep17TokenNative) burn(ic *interop.Context, h util.Uint160, amount *big.Int) {
@ -263,14 +280,14 @@ func (c *nep17TokenNative) burn(ic *interop.Context, h util.Uint160, amount *big
return return
} }
amount.Neg(amount) amount.Neg(amount)
c.addTokens(ic, h, amount) postF := c.addTokens(ic, h, amount)
amount.Neg(amount) amount.Neg(amount)
c.postTransfer(ic, &h, nil, amount, stackitem.Null{}, false) c.postTransfer(ic, &h, nil, amount, stackitem.Null{}, false, postF)
} }
func (c *nep17TokenNative) addTokens(ic *interop.Context, h util.Uint160, amount *big.Int) { func (c *nep17TokenNative) addTokens(ic *interop.Context, h util.Uint160, amount *big.Int) func() {
if amount.Sign() == 0 { if amount.Sign() == 0 {
return return nil
} }
key := makeAccountKey(h) key := makeAccountKey(h)
@ -278,7 +295,8 @@ func (c *nep17TokenNative) addTokens(ic *interop.Context, h util.Uint160, amount
if si == nil { if si == nil {
si = state.StorageItem{} si = state.StorageItem{}
} }
if err := c.incBalance(ic, h, &si, amount, nil); err != nil { postF, err := c.incBalance(ic, h, &si, amount, nil)
if err != nil {
panic(err) panic(err)
} }
if si == nil { if si == nil {
@ -290,6 +308,7 @@ func (c *nep17TokenNative) addTokens(ic *interop.Context, h util.Uint160, amount
buf, supply := c.getTotalSupply(ic.DAO) buf, supply := c.getTotalSupply(ic.DAO)
supply.Add(supply, amount) supply.Add(supply, amount)
c.saveTotalSupply(ic.DAO, buf, supply) c.saveTotalSupply(ic.DAO, buf, supply)
return postF
} }
func newDescriptor(name string, ret smartcontract.ParamType, ps ...manifest.Parameter) *manifest.Method { func newDescriptor(name string, ret smartcontract.ParamType, ps ...manifest.Parameter) *manifest.Method {

View file

@ -291,7 +291,7 @@ func TestNEO_TransferOnPayment(t *testing.T) {
h := neoValidatorsInvoker.Invoke(t, true, "transfer", e.Validator.ScriptHash(), cs.Hash, amount, nil) h := neoValidatorsInvoker.Invoke(t, true, "transfer", e.Validator.ScriptHash(), cs.Hash, amount, nil)
aer := e.GetTxExecResult(t, h) aer := e.GetTxExecResult(t, h)
require.Equal(t, 3, len(aer.Events)) // transfer + GAS claim for sender + onPayment require.Equal(t, 3, len(aer.Events)) // transfer + GAS claim for sender + onPayment
e.CheckTxNotificationEvent(t, h, 2, state.NotificationEvent{ e.CheckTxNotificationEvent(t, h, 1, state.NotificationEvent{
ScriptHash: cs.Hash, ScriptHash: cs.Hash,
Name: "LastPayment", Name: "LastPayment",
Item: stackitem.NewArray([]stackitem.Item{ Item: stackitem.NewArray([]stackitem.Item{
@ -305,17 +305,7 @@ func TestNEO_TransferOnPayment(t *testing.T) {
h = neoValidatorsInvoker.Invoke(t, true, "transfer", e.Validator.ScriptHash(), cs.Hash, amount, nil) h = neoValidatorsInvoker.Invoke(t, true, "transfer", e.Validator.ScriptHash(), cs.Hash, amount, nil)
aer = e.GetTxExecResult(t, h) aer = e.GetTxExecResult(t, h)
require.Equal(t, 5, len(aer.Events)) // Now we must also have GAS claim for contract and corresponding `onPayment`. require.Equal(t, 5, len(aer.Events)) // Now we must also have GAS claim for contract and corresponding `onPayment`.
e.CheckTxNotificationEvent(t, h, 2, state.NotificationEvent{ // onPayment for GAS claim e.CheckTxNotificationEvent(t, h, 1, state.NotificationEvent{ // onPayment for NEO transfer
ScriptHash: cs.Hash,
Name: "LastPayment",
Item: stackitem.NewArray([]stackitem.Item{
stackitem.NewByteArray(e.NativeHash(t, nativenames.Gas).BytesBE()),
stackitem.Null{},
stackitem.NewBigInteger(big.NewInt(1)),
stackitem.Null{},
}),
})
e.CheckTxNotificationEvent(t, h, 4, state.NotificationEvent{ // onPayment for NEO transfer
ScriptHash: cs.Hash, ScriptHash: cs.Hash,
Name: "LastPayment", Name: "LastPayment",
Item: stackitem.NewArray([]stackitem.Item{ Item: stackitem.NewArray([]stackitem.Item{
@ -325,6 +315,16 @@ func TestNEO_TransferOnPayment(t *testing.T) {
stackitem.Null{}, stackitem.Null{},
}), }),
}) })
e.CheckTxNotificationEvent(t, h, 4, state.NotificationEvent{ // onPayment for GAS claim
ScriptHash: cs.Hash,
Name: "LastPayment",
Item: stackitem.NewArray([]stackitem.Item{
stackitem.NewByteArray(e.NativeHash(t, nativenames.Gas).BytesBE()),
stackitem.Null{},
stackitem.NewBigInteger(big.NewInt(1)),
stackitem.Null{},
}),
})
} }
func TestNEO_Roundtrip(t *testing.T) { func TestNEO_Roundtrip(t *testing.T) {
@ -406,7 +406,7 @@ func TestNEO_TransferZeroWithNonZeroBalance(t *testing.T) {
aer, err := e.Chain.GetAppExecResults(h, trigger.Application) aer, err := e.Chain.GetAppExecResults(h, trigger.Application)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 2, len(aer[0].Events)) // roundtrip + GAS claim require.Equal(t, 2, len(aer[0].Events)) // roundtrip + GAS claim
require.Equal(t, stackitem.NewBigInteger(big.NewInt(0)), aer[0].Events[1].Item.Value().([]stackitem.Item)[2]) // amount is 0 require.Equal(t, stackitem.NewBigInteger(big.NewInt(0)), aer[0].Events[0].Item.Value().([]stackitem.Item)[2]) // amount is 0
// check balance wasn't changed and height was updated // check balance wasn't changed and height was updated
updatedBalance, updatedHeight := e.Chain.GetGoverningTokenBalance(acc.ScriptHash()) updatedBalance, updatedHeight := e.Chain.GetGoverningTokenBalance(acc.ScriptHash())
require.Equal(t, initialBalance, updatedBalance) require.Equal(t, initialBalance, updatedBalance)