diff --git a/pkg/vm/vm.go b/pkg/vm/vm.go index c48d993ec..f289bc15a 100644 --- a/pkg/vm/vm.go +++ b/pkg/vm/vm.go @@ -63,6 +63,15 @@ const ( // SyscallHandler is a type for syscall handler. type SyscallHandler = func(*VM, uint32) error +// OnExecHook is a type for a callback that is invoked +// before each instruction is executed. +type OnExecHook = func(scriptHash util.Uint160, offset int, opcode opcode.Opcode) + +// A struct that contains all VM hooks. +type hooks struct { + onExec OnExecHook +} + // VM represents the virtual machine. type VM struct { state vmstate.State @@ -90,6 +99,9 @@ type VM struct { // invTree is a top-level invocation tree (if enabled). invTree *invocations.Tree + + // All registered hooks. + hooks hooks } var ( @@ -116,6 +128,16 @@ func NewWithTrigger(t trigger.Type) *VM { return vm } +// SetOnExecHook sets the value of OnExecHook which +// will be invoked for each executed instruction. +// This function panics if the VM has been started. +func (v *VM) SetOnExecHook(hook OnExecHook) { + if v.state != vmstate.None { + panic("Cannot set onExec hook of a started VM") + } + v.hooks.onExec = hook +} + // SetPriceGetter registers the given PriceGetterFunc in v. // f accepts vm's Context, current instruction and instruction parameter. func (v *VM) SetPriceGetter(f func(opcode.Opcode, []byte) int64) { @@ -472,7 +494,12 @@ func (v *VM) Step() error { // step executes one instruction in the given context. func (v *VM) step(ctx *Context) error { + ip := ctx.nextip + scriptHash := v.GetCurrentScriptHash() op, param, err := ctx.Next() + if v.hooks.onExec != nil { + v.hooks.onExec(scriptHash, ip, op) + } if err != nil { v.state = vmstate.Fault return newError(ctx.ip, op, err) diff --git a/pkg/vm/vm_test.go b/pkg/vm/vm_test.go index 2107f749b..e82b78eca 100644 --- a/pkg/vm/vm_test.go +++ b/pkg/vm/vm_test.go @@ -18,6 +18,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/smartcontract/callflag" "github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger" + "github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/vm/emit" "github.com/nspcc-dev/neo-go/pkg/vm/opcode" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" @@ -2776,6 +2777,50 @@ func TestUninitializedSyscallHandler(t *testing.T) { assert.Equal(t, true, v.HasFailed()) } +func TestCannotSetOnExecHookOfStartedVm(t *testing.T) { + prog := makeProgram(opcode.NOP) + v := load(prog) + runVM(t, v) + require.Panics(t, func() { + v.SetOnExecHook(func(scriptHash util.Uint160, offset int, opcode opcode.Opcode) {}) + }) +} + +func TestOnExecHookGivesValidTrace(t *testing.T) { + prog := makeProgram(opcode.NOP, opcode.NOP, opcode.NOP) + expectedOffsets := []int{0, 1, 2, 3} + expectedOpcodes := []opcode.Opcode{opcode.NOP, opcode.NOP, opcode.NOP, opcode.RET} + + actualScriptHashes, actualOffsets, actualOpcodes := runWithTrace(t, prog) + + require.Equal(t, expectedOffsets, actualOffsets, "Invalid offsets") + require.Equal(t, expectedOpcodes, actualOpcodes, "Invalid opcodes") + + t.Run("Validate collected script hashes", func(t *testing.T) { + scriptHash := actualScriptHashes[0] + expectedScriptHashes := []util.Uint160{scriptHash, scriptHash, scriptHash, scriptHash} + require.Equal(t, expectedScriptHashes, actualScriptHashes) + }) +} + +func runWithTrace(t *testing.T, prog []byte) ([]util.Uint160, []int, []opcode.Opcode) { + v := load(prog) + + scriptHashes := make([]util.Uint160, 0) + offsets := make([]int, 0) + opcodes := make([]opcode.Opcode, 0) + + onExec := func(scriptHash util.Uint160, offset int, opcode opcode.Opcode) { + scriptHashes = append(scriptHashes, scriptHash) + offsets = append(offsets, offset) + opcodes = append(opcodes, opcode) + } + v.SetOnExecHook(onExec) + runVM(t, v) + + return scriptHashes, offsets, opcodes +} + func makeProgram(opcodes ...opcode.Opcode) []byte { prog := make([]byte, len(opcodes)+1) // RET for i := 0; i < len(opcodes); i++ {