diff --git a/pkg/vm/context.go b/pkg/vm/context.go index 6346fd21a..1eaba0ce0 100644 --- a/pkg/vm/context.go +++ b/pkg/vm/context.go @@ -52,6 +52,8 @@ type Context struct { retCount int // NEF represents NEF file for the current contract. NEF *nef.File + // invTree is an invocation tree (or branch of it) for this context. + invTree *InvocationTree } var errNoInstParam = errors.New("failed to read instruction parameter") diff --git a/pkg/vm/invocation_tree.go b/pkg/vm/invocation_tree.go new file mode 100644 index 000000000..b4daff18a --- /dev/null +++ b/pkg/vm/invocation_tree.go @@ -0,0 +1,12 @@ +package vm + +import ( + "github.com/nspcc-dev/neo-go/pkg/util" +) + +// InvocationTree represents a tree with script hashes, traversing it +// you can see how contracts called each other. +type InvocationTree struct { + Current util.Uint160 + Calls []*InvocationTree +} diff --git a/pkg/vm/invocation_tree_test.go b/pkg/vm/invocation_tree_test.go new file mode 100644 index 000000000..dc60ba677 --- /dev/null +++ b/pkg/vm/invocation_tree_test.go @@ -0,0 +1,69 @@ +package vm + +import ( + "testing" + + "github.com/nspcc-dev/neo-go/pkg/util" + "github.com/nspcc-dev/neo-go/pkg/vm/opcode" + "github.com/stretchr/testify/require" +) + +func TestInvocationTree(t *testing.T) { + script := []byte{ + byte(opcode.PUSH3), byte(opcode.DEC), + byte(opcode.DUP), byte(opcode.PUSH0), byte(opcode.JMPEQ), (2 + 2 + 2 + 6 + 1), + byte(opcode.CALL), (2 + 2), // CALL shouldn't affect invocation tree. + byte(opcode.JMP), 0xf9, // DEC + byte(opcode.SYSCALL), 0, 0, 0, 0, byte(opcode.DROP), + byte(opcode.RET), + byte(opcode.RET), + byte(opcode.PUSHINT8), 0xff, + } + + cnt := 0 + v := newTestVM() + v.SyscallHandler = func(v *VM, _ uint32) error { + if v.Istack().Len() > 4 { // top -> call -> syscall -> call -> syscall -> ... + v.Estack().PushVal(1) + return nil + } + cnt++ + v.LoadScriptWithHash(script, util.Uint160{byte(cnt)}, 0) + return nil + } + v.EnableInvocationTree() + v.LoadScript(script) + topHash := v.Context().ScriptHash() + require.NoError(t, v.Run()) + + res := &InvocationTree{ + Calls: []*InvocationTree{{ + Current: topHash, + Calls: []*InvocationTree{ + { + Current: util.Uint160{1}, + Calls: []*InvocationTree{ + { + Current: util.Uint160{2}, + }, + { + Current: util.Uint160{3}, + }, + }, + }, + { + Current: util.Uint160{4}, + Calls: []*InvocationTree{ + { + Current: util.Uint160{5}, + }, + { + Current: util.Uint160{6}, + }, + }, + }, + }, + }}, + } + require.Equal(t, res, v.GetInvocationTree()) +} diff --git a/pkg/vm/vm.go b/pkg/vm/vm.go index dc730495f..960b0e4f7 100644 --- a/pkg/vm/vm.go +++ b/pkg/vm/vm.go @@ -85,6 +85,9 @@ type VM struct { LoadToken func(id int32) error trigger trigger.Type + + // invTree is a top-level invocation tree (if enabled). + invTree *InvocationTree } // New returns a new VM object ready to load AVM bytecode scripts. @@ -240,6 +243,16 @@ func (v *VM) LoadFileWithFlags(path string, f callflag.CallFlag) error { return nil } +// CollectInvocationTree enables collecting invocation tree data. +func (v *VM) EnableInvocationTree() { + v.invTree = &InvocationTree{} +} + +// GetInvocationTree returns current invocation tree structure. +func (v *VM) GetInvocationTree() *InvocationTree { + return v.invTree +} + // Load initializes the VM with the program given. func (v *VM) Load(prog []byte) { v.LoadWithFlags(prog, callflag.NoneFlag) @@ -252,6 +265,7 @@ func (v *VM) LoadWithFlags(prog []byte, f callflag.CallFlag) { v.estack.Clear() v.state = NoneState v.gasConsumed = 0 + v.invTree = nil v.LoadScriptWithFlags(prog, f) } @@ -306,6 +320,16 @@ func (v *VM) loadScriptWithCallingHash(b []byte, exe *nef.File, caller util.Uint ctx.scriptHash = hash ctx.callingScriptHash = caller ctx.NEF = exe + if v.invTree != nil { + curTree := v.invTree + parent := v.Context() + if parent != nil { + curTree = parent.invTree + } + newTree := &InvocationTree{Current: ctx.ScriptHash()} + curTree.Calls = append(curTree.Calls, newTree) + ctx.invTree = newTree + } v.istack.PushItem(ctx) }