diff --git a/pkg/core/blockchain.go b/pkg/core/blockchain.go index 768c91367..0353c833e 100644 --- a/pkg/core/blockchain.go +++ b/pkg/core/blockchain.go @@ -946,11 +946,12 @@ func (bc *Blockchain) processNEP17Transfer(cache *dao.Cached, h util.Uint256, b bs.LastUpdatedBlock = b.Index balances.Trackers[id] = bs transfer.Amount = *new(big.Int).Sub(&transfer.Amount, amount) - isBig, err := cache.AppendNEP17Transfer(fromAddr, balances.NextTransferBatch, transfer) + balances.NewBatch, err = cache.AppendNEP17Transfer(fromAddr, + balances.NextTransferBatch, balances.NewBatch, transfer) if err != nil { return } - if isBig { + if balances.NewBatch { balances.NextTransferBatch++ } if err := cache.PutNEP17Balances(fromAddr, balances); err != nil { @@ -968,11 +969,12 @@ func (bc *Blockchain) processNEP17Transfer(cache *dao.Cached, h util.Uint256, b balances.Trackers[id] = bs transfer.Amount = *amount - isBig, err := cache.AppendNEP17Transfer(toAddr, balances.NextTransferBatch, transfer) + balances.NewBatch, err = cache.AppendNEP17Transfer(toAddr, + balances.NextTransferBatch, balances.NewBatch, transfer) if err != nil { return } - if isBig { + if balances.NewBatch { balances.NextTransferBatch++ } if err := cache.PutNEP17Balances(toAddr, balances); err != nil { diff --git a/pkg/core/dao/cacheddao.go b/pkg/core/dao/cacheddao.go index 32e5019b5..f23bb61df 100644 --- a/pkg/core/dao/cacheddao.go +++ b/pkg/core/dao/cacheddao.go @@ -61,10 +61,16 @@ func (cd *Cached) PutNEP17TransferLog(acc util.Uint160, index uint32, bs *state. } // AppendNEP17Transfer appends new transfer to a transfer event log. -func (cd *Cached) AppendNEP17Transfer(acc util.Uint160, index uint32, tr *state.NEP17Transfer) (bool, error) { - lg, err := cd.GetNEP17TransferLog(acc, index) - if err != nil { - return false, err +func (cd *Cached) AppendNEP17Transfer(acc util.Uint160, index uint32, isNew bool, tr *state.NEP17Transfer) (bool, error) { + var lg *state.NEP17TransferLog + if isNew { + lg = new(state.NEP17TransferLog) + } else { + var err error + lg, err = cd.GetNEP17TransferLog(acc, index) + if err != nil { + return false, err + } } if err := lg.Append(tr); err != nil { return false, err diff --git a/pkg/core/dao/dao.go b/pkg/core/dao/dao.go index a0c9480c9..b7ad486a5 100644 --- a/pkg/core/dao/dao.go +++ b/pkg/core/dao/dao.go @@ -31,7 +31,7 @@ var ( // DAO is a data access object. type DAO interface { AppendAppExecResult(aer *state.AppExecResult, buf *io.BufBinWriter) error - AppendNEP17Transfer(acc util.Uint160, index uint32, tr *state.NEP17Transfer) (bool, error) + AppendNEP17Transfer(acc util.Uint160, index uint32, isNew bool, tr *state.NEP17Transfer) (bool, error) DeleteBlock(h util.Uint256, buf *io.BufBinWriter) error DeleteContractID(id int32) error DeleteStorageItem(id int32, key []byte) error @@ -205,13 +205,16 @@ func (dao *Simple) PutNEP17TransferLog(acc util.Uint160, index uint32, lg *state // AppendNEP17Transfer appends a single NEP17 transfer to a log. // First return value signalizes that log size has exceeded batch size. -func (dao *Simple) AppendNEP17Transfer(acc util.Uint160, index uint32, tr *state.NEP17Transfer) (bool, error) { - lg, err := dao.GetNEP17TransferLog(acc, index) - if err != nil { - if err != storage.ErrKeyNotFound { +func (dao *Simple) AppendNEP17Transfer(acc util.Uint160, index uint32, isNew bool, tr *state.NEP17Transfer) (bool, error) { + var lg *state.NEP17TransferLog + if isNew { + lg = new(state.NEP17TransferLog) + } else { + var err error + lg, err = dao.GetNEP17TransferLog(acc, index) + if err != nil { return false, err } - lg = new(state.NEP17TransferLog) } if err := lg.Append(tr); err != nil { return false, err diff --git a/pkg/core/state/nep17.go b/pkg/core/state/nep17.go index 1e0b7e87f..dd692993a 100644 --- a/pkg/core/state/nep17.go +++ b/pkg/core/state/nep17.go @@ -50,6 +50,8 @@ type NEP17Balances struct { Trackers map[int32]NEP17Tracker // NextTransferBatch stores an index of the next transfer batch. NextTransferBatch uint32 + // NewBatch is true if batch with the `NextTransferBatch` index should be created. + NewBatch bool } // NewNEP17Balances returns new NEP17Balances. @@ -62,6 +64,7 @@ func NewNEP17Balances() *NEP17Balances { // DecodeBinary implements io.Serializable interface. func (bs *NEP17Balances) DecodeBinary(r *io.BinReader) { bs.NextTransferBatch = r.ReadU32LE() + bs.NewBatch = r.ReadBool() lenBalances := r.ReadVarUint() m := make(map[int32]NEP17Tracker, lenBalances) for i := 0; i < int(lenBalances); i++ { @@ -76,6 +79,7 @@ func (bs *NEP17Balances) DecodeBinary(r *io.BinReader) { // EncodeBinary implements io.Serializable interface. func (bs *NEP17Balances) EncodeBinary(w *io.BinWriter) { w.WriteU32LE(bs.NextTransferBatch) + w.WriteBool(bs.NewBatch) w.WriteVarUint(uint64(len(bs.Trackers))) for k, v := range bs.Trackers { w.WriteU32LE(uint32(k))