Merge pull request #2812 from nspcc-dev/improve-vm-context-handling

Improve vm istack/estack handling
This commit is contained in:
Roman Khimov 2022-11-20 19:42:35 +07:00 committed by GitHub
commit 48140320db
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 46 additions and 90 deletions

View file

@ -8,6 +8,7 @@ import (
"testing" "testing"
"github.com/nspcc-dev/neo-go/pkg/compiler" "github.com/nspcc-dev/neo-go/pkg/compiler"
"github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger"
"github.com/nspcc-dev/neo-go/pkg/vm" "github.com/nspcc-dev/neo-go/pkg/vm"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -149,8 +150,7 @@ func TestAssignments(t *testing.T) {
for i, tc := range assignTestCases { for i, tc := range assignTestCases {
v := vm.New() v := vm.New()
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
v.Istack().Clear() v.Reset(trigger.Application)
v.Estack().Clear()
invokeMethod(t, fmt.Sprintf("F%d", i), ne.Script, v, di) invokeMethod(t, fmt.Sprintf("F%d", i), ne.Script, v, di)
runAndCheck(t, v, tc.result) runAndCheck(t, v, tc.result)
}) })

View file

@ -8,6 +8,7 @@ import (
"testing" "testing"
"github.com/nspcc-dev/neo-go/pkg/compiler" "github.com/nspcc-dev/neo-go/pkg/compiler"
"github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger"
"github.com/nspcc-dev/neo-go/pkg/vm" "github.com/nspcc-dev/neo-go/pkg/vm"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -238,8 +239,7 @@ func TestBinaryExprs(t *testing.T) {
for i, tc := range binaryExprTestCases { for i, tc := range binaryExprTestCases {
v := vm.New() v := vm.New()
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
v.Istack().Clear() v.Reset(trigger.Application)
v.Estack().Clear()
invokeMethod(t, fmt.Sprintf("F%d", i), ne.Script, v, di) invokeMethod(t, fmt.Sprintf("F%d", i), ne.Script, v, di)
runAndCheck(t, v, tc.result) runAndCheck(t, v, tc.result)
}) })

View file

@ -8,6 +8,7 @@ import (
"testing" "testing"
"github.com/nspcc-dev/neo-go/pkg/compiler" "github.com/nspcc-dev/neo-go/pkg/compiler"
"github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger"
"github.com/nspcc-dev/neo-go/pkg/vm" "github.com/nspcc-dev/neo-go/pkg/vm"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -72,8 +73,7 @@ func TestConvert(t *testing.T) {
for i, tc := range convertTestCases { for i, tc := range convertTestCases {
v := vm.New() v := vm.New()
t.Run(tc.argValue+getFunctionName(tc.returnType), func(t *testing.T) { t.Run(tc.argValue+getFunctionName(tc.returnType), func(t *testing.T) {
v.Istack().Clear() v.Reset(trigger.Application)
v.Estack().Clear()
invokeMethod(t, fmt.Sprintf("F%d", i), ne.Script, v, di) invokeMethod(t, fmt.Sprintf("F%d", i), ne.Script, v, di)
runAndCheck(t, v, tc.result) runAndCheck(t, v, tc.result)
}) })

View file

@ -8,6 +8,7 @@ import (
"testing" "testing"
"github.com/nspcc-dev/neo-go/pkg/compiler" "github.com/nspcc-dev/neo-go/pkg/compiler"
"github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger"
"github.com/nspcc-dev/neo-go/pkg/vm" "github.com/nspcc-dev/neo-go/pkg/vm"
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -727,8 +728,7 @@ func TestForLoop(t *testing.T) {
for i, tc := range forLoopTestCases { for i, tc := range forLoopTestCases {
v := vm.New() v := vm.New()
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
v.Istack().Clear() v.Reset(trigger.Application)
v.Estack().Clear()
invokeMethod(t, fmt.Sprintf("F%d", i), ne.Script, v, di) invokeMethod(t, fmt.Sprintf("F%d", i), ne.Script, v, di)
runAndCheck(t, v, tc.result) runAndCheck(t, v, tc.result)
}) })
@ -785,8 +785,7 @@ func TestForLoopComplexConditions(t *testing.T) {
name = tc.Assign name = tc.Assign
} }
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
v.Istack().Clear() v.Reset(trigger.Application)
v.Estack().Clear()
invokeMethod(t, fmt.Sprintf("F%d", i), ne.Script, v, di) invokeMethod(t, fmt.Sprintf("F%d", i), ne.Script, v, di)
runAndCheck(t, v, big.NewInt(tc.Result)) runAndCheck(t, v, big.NewInt(tc.Result))
}) })

View file

@ -84,7 +84,7 @@ func evalWithArgs(t *testing.T, src string, op []byte, args []stackitem.Item, re
func assertResult(t *testing.T, vm *vm.VM, result interface{}) { func assertResult(t *testing.T, vm *vm.VM, result interface{}) {
assert.Equal(t, result, vm.PopResult()) assert.Equal(t, result, vm.PopResult())
assert.Equal(t, 0, vm.Istack().Len()) assert.Nil(t, vm.Context())
} }
func vmAndCompile(t *testing.T, src string) *vm.VM { func vmAndCompile(t *testing.T, src string) *vm.VM {

View file

@ -163,12 +163,12 @@ var ErrNativeCall = errors.New("failed native call")
// CallFromNative performs synchronous call from native contract. // CallFromNative performs synchronous call from native contract.
func CallFromNative(ic *interop.Context, caller util.Uint160, cs *state.Contract, method string, args []stackitem.Item, hasReturn bool) error { func CallFromNative(ic *interop.Context, caller util.Uint160, cs *state.Contract, method string, args []stackitem.Item, hasReturn bool) error {
startSize := ic.VM.Istack().Len() startSize := len(ic.VM.Istack())
if err := callExFromNative(ic, caller, cs, method, args, callflag.All, hasReturn, false, true); err != nil { if err := callExFromNative(ic, caller, cs, method, args, callflag.All, hasReturn, false, true); err != nil {
return err return err
} }
for !ic.VM.HasStopped() && ic.VM.Istack().Len() > startSize { for !ic.VM.HasStopped() && len(ic.VM.Istack()) > startSize {
if err := ic.VM.Step(); err != nil { if err := ic.VM.Step(); err != nil {
return fmt.Errorf("%w: %v", ErrNativeCall, err) return fmt.Errorf("%w: %v", ErrNativeCall, err)
} }

View file

@ -41,7 +41,7 @@ func GetCallingScriptHash(ic *interop.Context) error {
// GetEntryScriptHash returns entry script hash. // GetEntryScriptHash returns entry script hash.
func GetEntryScriptHash(ic *interop.Context) error { func GetEntryScriptHash(ic *interop.Context) error {
return ic.VM.PushContextScriptHash(ic.VM.Istack().Len() - 1) return ic.VM.PushContextScriptHash(len(ic.VM.Istack()) - 1)
} }
// GetScriptContainer returns transaction or block that contains the script // GetScriptContainer returns transaction or block that contains the script

View file

@ -276,7 +276,7 @@ func (o *Oracle) finish(ic *interop.Context, _ []stackitem.Item) stackitem.Item
// FinishInternal processes an oracle response. // FinishInternal processes an oracle response.
func (o *Oracle) FinishInternal(ic *interop.Context) error { func (o *Oracle) FinishInternal(ic *interop.Context) error {
if ic.VM.Istack().Len() != 2 { if len(ic.VM.Istack()) != 2 {
return errors.New("Oracle.finish called from non-entry script") return errors.New("Oracle.finish called from non-entry script")
} }
if ic.Invocations[o.Hash] != 1 { if ic.Invocations[o.Hash] != 1 {

View file

@ -5,7 +5,6 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"math/big"
"github.com/nspcc-dev/neo-go/pkg/crypto/hash" "github.com/nspcc-dev/neo-go/pkg/crypto/hash"
"github.com/nspcc-dev/neo-go/pkg/smartcontract/callflag" "github.com/nspcc-dev/neo-go/pkg/smartcontract/callflag"
@ -251,42 +250,6 @@ func (c *Context) NumOfReturnVals() int {
return c.retCount return c.retCount
} }
// Value implements the stackitem.Item interface.
func (c *Context) Value() interface{} {
return c
}
// Dup implements the stackitem.Item interface.
func (c *Context) Dup() stackitem.Item {
return c
}
// TryBool implements the stackitem.Item interface.
func (c *Context) TryBool() (bool, error) { panic("can't convert Context to Bool") }
// TryBytes implements the stackitem.Item interface.
func (c *Context) TryBytes() ([]byte, error) {
return nil, errors.New("can't convert Context to ByteArray")
}
// TryInteger implements the stackitem.Item interface.
func (c *Context) TryInteger() (*big.Int, error) {
return nil, errors.New("can't convert Context to Integer")
}
// Type implements the stackitem.Item interface.
func (c *Context) Type() stackitem.Type { panic("Context cannot appear on evaluation stack") }
// Convert implements the stackitem.Item interface.
func (c *Context) Convert(_ stackitem.Type) (stackitem.Item, error) {
panic("Context cannot be converted to anything")
}
// Equals implements the stackitem.Item interface.
func (c *Context) Equals(s stackitem.Item) bool {
return c == s
}
func (c *Context) atBreakPoint() bool { func (c *Context) atBreakPoint() bool {
for _, n := range c.sc.breakPoints { for _, n := range c.sc.breakPoints {
if n == c.nextip { if n == c.nextip {
@ -296,10 +259,6 @@ func (c *Context) atBreakPoint() bool {
return false return false
} }
func (c *Context) String() string {
return "execution context"
}
// IsDeployed returns whether this context contains a deployed contract. // IsDeployed returns whether this context contains a deployed contract.
func (c *Context) IsDeployed() bool { func (c *Context) IsDeployed() bool {
return c.sc.NEF != nil return c.sc.NEF != nil
@ -332,13 +291,10 @@ func dumpSlot(s *slot) string {
// getContextScriptHash returns script hash of the invocation stack element // getContextScriptHash returns script hash of the invocation stack element
// number n. // number n.
func (v *VM) getContextScriptHash(n int) util.Uint160 { func (v *VM) getContextScriptHash(n int) util.Uint160 {
istack := v.Istack() if len(v.istack) <= n {
if istack.Len() <= n {
return util.Uint160{} return util.Uint160{}
} }
element := istack.Peek(n) return v.istack[len(v.istack)-1-n].ScriptHash()
ctx := element.value.(*Context)
return ctx.ScriptHash()
} }
// IsCalledByEntry checks parent script contexts and return true if the current one // IsCalledByEntry checks parent script contexts and return true if the current one

View file

@ -24,7 +24,7 @@ func TestInvocationTree(t *testing.T) {
cnt := 0 cnt := 0
v := newTestVM() v := newTestVM()
v.SyscallHandler = func(v *VM, _ uint32) error { v.SyscallHandler = func(v *VM, _ uint32) error {
if v.Istack().Len() > 4 { // top -> call -> syscall -> call -> syscall -> ... if len(v.Istack()) > 4 { // top -> call -> syscall -> call -> syscall -> ...
v.Estack().PushVal(1) v.Estack().PushVal(1)
return nil return nil
} }

View file

@ -166,7 +166,7 @@ func testFile(t *testing.T, filename string) {
if len(result.InvocationStack) > 0 { if len(result.InvocationStack) > 0 {
for i, s := range result.InvocationStack { for i, s := range result.InvocationStack {
ctx := vm.istack.Peek(i).Value().(*Context) ctx := vm.istack[len(vm.istack)-1-i]
if ctx.nextip < len(ctx.sc.prog) { if ctx.nextip < len(ctx.sc.prog) {
require.Equal(t, s.InstructionPointer, ctx.nextip) require.Equal(t, s.InstructionPointer, ctx.nextip)
op, err := opcode.FromString(s.Instruction) op, err := opcode.FromString(s.Instruction)

View file

@ -69,7 +69,7 @@ type VM struct {
// callback to get interop price // callback to get interop price
getPrice func(opcode.Opcode, []byte) int64 getPrice func(opcode.Opcode, []byte) int64
istack Stack // invocation stack. istack []*Context // invocation stack.
estack *Stack // execution stack. estack *Stack // execution stack.
uncaughtException stackitem.Item // exception being handled uncaughtException stackitem.Item // exception being handled
@ -110,8 +110,7 @@ func NewWithTrigger(t trigger.Type) *VM {
trigger: t, trigger: t,
} }
initStack(&vm.istack, "invocation", nil) vm.istack = make([]*Context, 0, 8) // Most of invocations use one-two contracts, but they're likely to have internal calls.
vm.istack.elems = make([]Element, 0, 8) // Most of invocations use one-two contracts, but they're likely to have internal calls.
vm.estack = newStack("evaluation", &vm.refs) vm.estack = newStack("evaluation", &vm.refs)
return vm return vm
} }
@ -128,7 +127,7 @@ func (v *VM) SetPriceGetter(f func(opcode.Opcode, []byte) int64) {
func (v *VM) Reset(t trigger.Type) { func (v *VM) Reset(t trigger.Type) {
v.state = vmstate.None v.state = vmstate.None
v.getPrice = nil v.getPrice = nil
v.istack.elems = v.istack.elems[:0] v.istack = v.istack[:0]
v.estack.elems = v.estack.elems[:0] v.estack.elems = v.estack.elems[:0]
v.uncaughtException = nil v.uncaughtException = nil
v.refs = 0 v.refs = 0
@ -157,8 +156,8 @@ func (v *VM) Estack() *Stack {
} }
// Istack returns the invocation stack, so interop hooks can utilize this. // Istack returns the invocation stack, so interop hooks can utilize this.
func (v *VM) Istack() *Stack { func (v *VM) Istack() []*Context {
return &v.istack return v.istack
} }
// PrintOps prints the opcodes of the current loaded program to stdout. // PrintOps prints the opcodes of the current loaded program to stdout.
@ -284,7 +283,7 @@ func (v *VM) Load(prog []byte) {
// LoadWithFlags initializes the VM with the program and flags given. // LoadWithFlags initializes the VM with the program and flags given.
func (v *VM) LoadWithFlags(prog []byte, f callflag.CallFlag) { func (v *VM) LoadWithFlags(prog []byte, f callflag.CallFlag) {
// Clear all stacks and state, it could be a reload. // Clear all stacks and state, it could be a reload.
v.istack.Clear() v.istack = v.istack[:0]
v.estack.Clear() v.estack.Clear()
v.state = vmstate.None v.state = vmstate.None
v.gasConsumed = 0 v.gasConsumed = 0
@ -359,16 +358,16 @@ func (v *VM) loadScriptWithCallingHash(b []byte, exe *nef.File, caller util.Uint
ctx.sc.invTree = newTree ctx.sc.invTree = newTree
} }
ctx.sc.onUnload = onContextUnload ctx.sc.onUnload = onContextUnload
v.istack.PushItem(ctx) v.istack = append(v.istack, ctx)
} }
// Context returns the current executed context. Nil if there is no context, // Context returns the current executed context. Nil if there is no context,
// which implies no program is loaded. // which implies no program is loaded.
func (v *VM) Context() *Context { func (v *VM) Context() *Context {
if v.istack.Len() == 0 { if len(v.istack) == 0 {
return nil return nil
} }
return v.istack.Peek(0).value.(*Context) return v.istack[len(v.istack)-1]
} }
// PopResult is used to pop the first item of the evaluation stack. This allows // PopResult is used to pop the first item of the evaluation stack. This allows
@ -382,7 +381,7 @@ func (v *VM) PopResult() interface{} {
// DumpIStack returns json formatted representation of the invocation stack. // DumpIStack returns json formatted representation of the invocation stack.
func (v *VM) DumpIStack() string { func (v *VM) DumpIStack() string {
b, _ := json.MarshalIndent(v.istack.ToArray(), "", " ") b, _ := json.MarshalIndent(v.istack, "", " ")
return string(b) return string(b)
} }
@ -405,7 +404,7 @@ func (v *VM) State() vmstate.State {
// Ready returns true if the VM is ready to execute the loaded program. // Ready returns true if the VM is ready to execute the loaded program.
// It will return false if no program is loaded. // It will return false if no program is loaded.
func (v *VM) Ready() bool { func (v *VM) Ready() bool {
return v.istack.Len() > 0 return len(v.istack) > 0
} }
// Run starts execution of the loaded program. // Run starts execution of the loaded program.
@ -505,8 +504,8 @@ func (v *VM) StepOut() error {
v.state = vmstate.None v.state = vmstate.None
} }
expSize := v.istack.Len() expSize := len(v.istack)
for v.state == vmstate.None && v.istack.Len() >= expSize { for v.state == vmstate.None && len(v.istack) >= expSize {
err = v.StepInto() err = v.StepInto()
} }
if v.state == vmstate.None { if v.state == vmstate.None {
@ -527,10 +526,10 @@ func (v *VM) StepOver() error {
v.state = vmstate.None v.state = vmstate.None
} }
expSize := v.istack.Len() expSize := len(v.istack)
for { for {
err = v.StepInto() err = v.StepInto()
if !(v.state == vmstate.None && v.istack.Len() > expSize) { if !(v.state == vmstate.None && len(v.istack) > expSize) {
break break
} }
} }
@ -1467,11 +1466,12 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro
} }
case opcode.RET: case opcode.RET:
oldCtx := v.istack.Pop().value.(*Context) oldCtx := v.istack[len(v.istack)-1]
v.istack = v.istack[:len(v.istack)-1]
oldEstack := v.estack oldEstack := v.estack
v.unloadContext(oldCtx) v.unloadContext(oldCtx)
if v.istack.Len() == 0 { if len(v.istack) == 0 {
v.state = vmstate.Halt v.state = vmstate.Halt
break break
} }
@ -1701,7 +1701,7 @@ func (v *VM) call(ctx *Context, offset int) {
} }
// New context -> new exception handlers. // New context -> new exception handlers.
newCtx.tryStack.elems = ctx.tryStack.elems[len(ctx.tryStack.elems):] newCtx.tryStack.elems = ctx.tryStack.elems[len(ctx.tryStack.elems):]
v.istack.PushItem(newCtx) v.istack = append(v.istack, newCtx)
newCtx.Jump(offset) newCtx.Jump(offset)
} }
@ -1739,9 +1739,8 @@ func calcJumpOffset(ctx *Context, parameter []byte) (int, int, error) {
} }
func (v *VM) handleException() { func (v *VM) handleException() {
for pop := 0; pop < v.istack.Len(); pop++ { for pop := 0; pop < len(v.istack); pop++ {
ictxv := v.istack.Peek(pop) ictx := v.istack[len(v.istack)-1-pop]
ictx := ictxv.value.(*Context)
for j := 0; j < ictx.tryStack.Len(); j++ { for j := 0; j < ictx.tryStack.Len(); j++ {
e := ictx.tryStack.Peek(j) e := ictx.tryStack.Peek(j)
ectx := e.Value().(*exceptionHandlingContext) ectx := e.Value().(*exceptionHandlingContext)
@ -1751,9 +1750,11 @@ func (v *VM) handleException() {
continue continue
} }
for i := 0; i < pop; i++ { for i := 0; i < pop; i++ {
ctx := v.istack.Pop().value.(*Context) ctx := v.istack[len(v.istack)-1]
v.istack = v.istack[:len(v.istack)-1]
v.unloadContext(ctx) v.unloadContext(ctx)
} }
v.estack = ictx.sc.estack
if ectx.State == eTry && ectx.HasCatch() { if ectx.State == eTry && ectx.HasCatch() {
ectx.State = eCatch ectx.State = eCatch
v.estack.PushItem(v.uncaughtException) v.estack.PushItem(v.uncaughtException)
@ -1937,7 +1938,7 @@ func validateMapKey(key Element) {
} }
func (v *VM) checkInvocationStackSize() { func (v *VM) checkInvocationStackSize() {
if v.istack.Len() >= MaxInvocationStackSize { if len(v.istack) >= MaxInvocationStackSize {
panic("invocation stack is too big") panic("invocation stack is too big")
} }
} }
@ -1959,7 +1960,7 @@ func (v *VM) GetCallingScriptHash() util.Uint160 {
// GetEntryScriptHash implements the ScriptHashGetter interface. // GetEntryScriptHash implements the ScriptHashGetter interface.
func (v *VM) GetEntryScriptHash() util.Uint160 { func (v *VM) GetEntryScriptHash() util.Uint160 {
return v.getContextScriptHash(v.istack.Len() - 1) return v.getContextScriptHash(len(v.istack) - 1)
} }
// GetCurrentScriptHash implements the ScriptHashGetter interface. // GetCurrentScriptHash implements the ScriptHashGetter interface.

View file

@ -124,7 +124,7 @@ func TestPushBytes1to75(t *testing.T) {
errExec := vm.execute(nil, opcode.RET, nil) errExec := vm.execute(nil, opcode.RET, nil)
require.NoError(t, errExec) require.NoError(t, errExec)
assert.Equal(t, 0, vm.istack.Len()) assert.Nil(t, vm.Context())
buf.Reset() buf.Reset()
} }
} }