mirror of
https://github.com/nspcc-dev/neo-go.git
synced 2024-11-21 23:29:38 +00:00
Merge pull request #3460 from NeoGoBros/add-onexec-hook
Implement OnExecHook VM API
This commit is contained in:
commit
6f77195ce3
2 changed files with 72 additions and 0 deletions
27
pkg/vm/vm.go
27
pkg/vm/vm.go
|
@ -64,6 +64,15 @@ const (
|
||||||
// SyscallHandler is a type for syscall handler.
|
// SyscallHandler is a type for syscall handler.
|
||||||
type SyscallHandler = func(*VM, uint32) error
|
type SyscallHandler = func(*VM, uint32) error
|
||||||
|
|
||||||
|
// OnExecHook is a type for a callback that is invoked
|
||||||
|
// before each instruction is executed.
|
||||||
|
type OnExecHook = func(scriptHash util.Uint160, offset int, opcode opcode.Opcode)
|
||||||
|
|
||||||
|
// A struct that contains all VM hooks.
|
||||||
|
type hooks struct {
|
||||||
|
onExec OnExecHook
|
||||||
|
}
|
||||||
|
|
||||||
// VM represents the virtual machine.
|
// VM represents the virtual machine.
|
||||||
type VM struct {
|
type VM struct {
|
||||||
state vmstate.State
|
state vmstate.State
|
||||||
|
@ -91,6 +100,9 @@ type VM struct {
|
||||||
|
|
||||||
// invTree is a top-level invocation tree (if enabled).
|
// invTree is a top-level invocation tree (if enabled).
|
||||||
invTree *invocations.Tree
|
invTree *invocations.Tree
|
||||||
|
|
||||||
|
// All registered hooks.
|
||||||
|
hooks hooks
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -117,6 +129,16 @@ func NewWithTrigger(t trigger.Type) *VM {
|
||||||
return vm
|
return vm
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetOnExecHook sets the value of OnExecHook which
|
||||||
|
// will be invoked for each executed instruction.
|
||||||
|
// This function panics if the VM has been started.
|
||||||
|
func (v *VM) SetOnExecHook(hook OnExecHook) {
|
||||||
|
if v.state != vmstate.None {
|
||||||
|
panic("Cannot set onExec hook of a started VM")
|
||||||
|
}
|
||||||
|
v.hooks.onExec = hook
|
||||||
|
}
|
||||||
|
|
||||||
// SetPriceGetter registers the given PriceGetterFunc in v.
|
// SetPriceGetter registers the given PriceGetterFunc in v.
|
||||||
// f accepts vm's Context, current instruction and instruction parameter.
|
// f accepts vm's Context, current instruction and instruction parameter.
|
||||||
func (v *VM) SetPriceGetter(f func(opcode.Opcode, []byte) int64) {
|
func (v *VM) SetPriceGetter(f func(opcode.Opcode, []byte) int64) {
|
||||||
|
@ -474,7 +496,12 @@ func (v *VM) Step() error {
|
||||||
|
|
||||||
// step executes one instruction in the given context.
|
// step executes one instruction in the given context.
|
||||||
func (v *VM) step(ctx *Context) error {
|
func (v *VM) step(ctx *Context) error {
|
||||||
|
ip := ctx.nextip
|
||||||
|
scriptHash := v.GetCurrentScriptHash()
|
||||||
op, param, err := ctx.Next()
|
op, param, err := ctx.Next()
|
||||||
|
if v.hooks.onExec != nil {
|
||||||
|
v.hooks.onExec(scriptHash, ip, op)
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
v.state = vmstate.Fault
|
v.state = vmstate.Fault
|
||||||
return newError(ctx.ip, op, err)
|
return newError(ctx.ip, op, err)
|
||||||
|
|
|
@ -18,6 +18,7 @@ import (
|
||||||
"github.com/nspcc-dev/neo-go/pkg/io"
|
"github.com/nspcc-dev/neo-go/pkg/io"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/smartcontract/callflag"
|
"github.com/nspcc-dev/neo-go/pkg/smartcontract/callflag"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger"
|
"github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger"
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/vm/emit"
|
"github.com/nspcc-dev/neo-go/pkg/vm/emit"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/vm/opcode"
|
"github.com/nspcc-dev/neo-go/pkg/vm/opcode"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
|
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
|
||||||
|
@ -2776,6 +2777,50 @@ func TestUninitializedSyscallHandler(t *testing.T) {
|
||||||
assert.Equal(t, true, v.HasFailed())
|
assert.Equal(t, true, v.HasFailed())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCannotSetOnExecHookOfStartedVm(t *testing.T) {
|
||||||
|
prog := makeProgram(opcode.NOP)
|
||||||
|
v := load(prog)
|
||||||
|
runVM(t, v)
|
||||||
|
require.Panics(t, func() {
|
||||||
|
v.SetOnExecHook(func(scriptHash util.Uint160, offset int, opcode opcode.Opcode) {})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOnExecHookGivesValidTrace(t *testing.T) {
|
||||||
|
prog := makeProgram(opcode.NOP, opcode.NOP, opcode.NOP)
|
||||||
|
expectedOffsets := []int{0, 1, 2, 3}
|
||||||
|
expectedOpcodes := []opcode.Opcode{opcode.NOP, opcode.NOP, opcode.NOP, opcode.RET}
|
||||||
|
|
||||||
|
actualScriptHashes, actualOffsets, actualOpcodes := runWithTrace(t, prog)
|
||||||
|
|
||||||
|
require.Equal(t, expectedOffsets, actualOffsets, "Invalid offsets")
|
||||||
|
require.Equal(t, expectedOpcodes, actualOpcodes, "Invalid opcodes")
|
||||||
|
|
||||||
|
t.Run("Validate collected script hashes", func(t *testing.T) {
|
||||||
|
scriptHash := actualScriptHashes[0]
|
||||||
|
expectedScriptHashes := []util.Uint160{scriptHash, scriptHash, scriptHash, scriptHash}
|
||||||
|
require.Equal(t, expectedScriptHashes, actualScriptHashes)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func runWithTrace(t *testing.T, prog []byte) ([]util.Uint160, []int, []opcode.Opcode) {
|
||||||
|
v := load(prog)
|
||||||
|
|
||||||
|
scriptHashes := make([]util.Uint160, 0)
|
||||||
|
offsets := make([]int, 0)
|
||||||
|
opcodes := make([]opcode.Opcode, 0)
|
||||||
|
|
||||||
|
onExec := func(scriptHash util.Uint160, offset int, opcode opcode.Opcode) {
|
||||||
|
scriptHashes = append(scriptHashes, scriptHash)
|
||||||
|
offsets = append(offsets, offset)
|
||||||
|
opcodes = append(opcodes, opcode)
|
||||||
|
}
|
||||||
|
v.SetOnExecHook(onExec)
|
||||||
|
runVM(t, v)
|
||||||
|
|
||||||
|
return scriptHashes, offsets, opcodes
|
||||||
|
}
|
||||||
|
|
||||||
func makeProgram(opcodes ...opcode.Opcode) []byte {
|
func makeProgram(opcodes ...opcode.Opcode) []byte {
|
||||||
prog := make([]byte, len(opcodes)+1) // RET
|
prog := make([]byte, len(opcodes)+1) // RET
|
||||||
for i := 0; i < len(opcodes); i++ {
|
for i := 0; i < len(opcodes); i++ {
|
||||||
|
|
Loading…
Reference in a new issue