Merge pull request #1455 from nspcc-dev/get_invocation_counter

core: remove error from runtime.GetInvocationCounter
This commit is contained in:
Roman Khimov 2020-10-08 16:34:40 +03:00 committed by GitHub
commit 124ce9d247
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 26 additions and 11 deletions

View file

@ -35,7 +35,6 @@ type Context struct {
DAO *dao.Cached DAO *dao.Cached
Notifications []state.NotificationEvent Notifications []state.NotificationEvent
Log *zap.Logger Log *zap.Logger
Invocations map[util.Uint160]int
VM *vm.VM VM *vm.VM
Functions [][]Function Functions [][]Function
} }
@ -53,7 +52,6 @@ func NewContext(trigger trigger.Type, bc blockchainer.Blockchainer, d dao.DAO, n
DAO: dao, DAO: dao,
Notifications: nes, Notifications: nes,
Log: log, Log: log,
Invocations: make(map[util.Uint160]int),
// Functions is a slice of slices of interops sorted by ID. // Functions is a slice of slices of interops sorted by ID.
Functions: [][]Function{}, Functions: [][]Function{},
} }

View file

@ -67,7 +67,7 @@ func CallExInternal(ic *interop.Context, cs *state.Contract,
} }
u := cs.ScriptHash() u := cs.ScriptHash()
ic.Invocations[u]++ ic.VM.Invocations[u]++
ic.VM.LoadScriptWithHash(cs.Script, u, ic.VM.Context().GetCallFlags()&f) ic.VM.LoadScriptWithHash(cs.Script, u, ic.VM.Context().GetCallFlags()&f)
var isNative bool var isNative bool
for i := range ic.Natives { for i := range ic.Natives {

View file

@ -58,9 +58,11 @@ func GetNotifications(ic *interop.Context) error {
// GetInvocationCounter returns how many times current contract was invoked during current tx execution. // GetInvocationCounter returns how many times current contract was invoked during current tx execution.
func GetInvocationCounter(ic *interop.Context) error { func GetInvocationCounter(ic *interop.Context) error {
count, ok := ic.Invocations[ic.VM.GetCurrentScriptHash()] currentScriptHash := ic.VM.GetCurrentScriptHash()
count, ok := ic.VM.Invocations[currentScriptHash]
if !ok { if !ok {
return errors.New("current contract wasn't invoked from others") count = 1
ic.VM.Invocations[currentScriptHash] = count
} }
ic.VM.Estack().PushVal(count) ic.VM.Estack().PushVal(count)
return nil return nil

View file

@ -296,11 +296,13 @@ func TestRuntimeGetInvocationCounter(t *testing.T) {
v, ic, chain := createVM(t) v, ic, chain := createVM(t)
defer chain.Close() defer chain.Close()
ic.Invocations[hash.Hash160([]byte{2})] = 42 ic.VM.Invocations[hash.Hash160([]byte{2})] = 42
t.Run("Zero", func(t *testing.T) { t.Run("No invocations", func(t *testing.T) {
v.LoadScript([]byte{1}) v.LoadScript([]byte{1})
require.Error(t, runtime.GetInvocationCounter(ic)) // do not return an error in this case.
require.NoError(t, runtime.GetInvocationCounter(ic))
require.EqualValues(t, 1, v.Estack().Pop().BigInt().Int64())
}) })
t.Run("NonZero", func(t *testing.T) { t.Run("NonZero", func(t *testing.T) {
v.LoadScript([]byte{2}) v.LoadScript([]byte{2})

View file

@ -80,6 +80,9 @@ type VM struct {
SyscallHandler func(v *VM, id uint32) error SyscallHandler func(v *VM, id uint32) error
trigger trigger.Type trigger trigger.Type
// Invocations is a script invocation counter.
Invocations map[util.Uint160]int
} }
// New returns a new VM object ready to load AVM bytecode scripts. // New returns a new VM object ready to load AVM bytecode scripts.
@ -96,6 +99,7 @@ func NewWithTrigger(t trigger.Type) *VM {
trigger: t, trigger: t,
SyscallHandler: defaultSyscallHandler, SyscallHandler: defaultSyscallHandler,
Invocations: make(map[util.Uint160]int),
} }
vm.estack = vm.newItemStack("evaluation") vm.estack = vm.newItemStack("evaluation")
@ -1225,7 +1229,7 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro
v.checkInvocationStackSize() v.checkInvocationStackSize()
// Note: jump offset must be calculated regarding to new context, // Note: jump offset must be calculated regarding to new context,
// but it is cloned and thus has the same script and instruction pointer. // but it is cloned and thus has the same script and instruction pointer.
v.Call(ctx, v.getJumpOffset(ctx, parameter)) v.call(ctx, v.getJumpOffset(ctx, parameter))
case opcode.CALLA: case opcode.CALLA:
ptr := v.estack.Pop().Item().(*stackitem.Pointer) ptr := v.estack.Pop().Item().(*stackitem.Pointer)
@ -1233,7 +1237,7 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro
panic("invalid script in pointer") panic("invalid script in pointer")
} }
v.Call(ctx, ptr.Position()) v.call(ctx, ptr.Position())
case opcode.SYSCALL: case opcode.SYSCALL:
interopID := GetInteropID(parameter) interopID := GetInteropID(parameter)
@ -1455,8 +1459,17 @@ func (v *VM) Jump(ctx *Context, offset int) {
} }
// Call calls method by offset. It is similar to Jump but also // Call calls method by offset. It is similar to Jump but also
// pushes new context to the invocation state // pushes new context to the invocation stack and increments
// invocation counter for the corresponding context script hash.
func (v *VM) Call(ctx *Context, offset int) { func (v *VM) Call(ctx *Context, offset int) {
v.call(ctx, offset)
v.Invocations[ctx.ScriptHash()]++
}
// call is an internal representation of Call, which does not
// affect the invocation counter and is only being used by vm
// package.
func (v *VM) call(ctx *Context, offset int) {
newCtx := ctx.Copy() newCtx := ctx.Copy()
newCtx.CheckReturn = false newCtx.CheckReturn = false
newCtx.local = nil newCtx.local = nil