vm: extract shared parts of the Context

Local calls reuse them, cross-contract calls create new ones. This allows to
avoid some allocations and use a little less memory.
This commit is contained in:
Roman Khimov 2022-08-04 16:15:51 +03:00
parent e5c59f8ddd
commit 13f5fdbe8a
7 changed files with 112 additions and 108 deletions

View file

@ -27,7 +27,7 @@ func LoadToken(ic *interop.Context, id int32) error {
if !ctx.GetCallFlags().Has(callflag.ReadStates | callflag.AllowCall) { if !ctx.GetCallFlags().Has(callflag.ReadStates | callflag.AllowCall) {
return errors.New("invalid call flags") return errors.New("invalid call flags")
} }
tok := ctx.NEF.Tokens[id] tok := ctx.GetNEF().Tokens[id]
if int(tok.ParamCount) > ctx.Estack().Len() { if int(tok.ParamCount) > ctx.Estack().Len() {
return errors.New("stack is too small") return errors.New("stack is too small")
} }

View file

@ -73,7 +73,7 @@ func Notify(ic *interop.Context) error {
if len(name) > MaxEventNameLen { if len(name) > MaxEventNameLen {
return fmt.Errorf("event name must be less than %d", MaxEventNameLen) return fmt.Errorf("event name must be less than %d", MaxEventNameLen)
} }
if ic.VM.Context().NEF == nil { if !ic.VM.Context().IsDeployed() {
return errors.New("notifications are not allowed in dynamic scripts") return errors.New("notifications are not allowed in dynamic scripts")
} }

View file

@ -16,14 +16,10 @@ import (
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
) )
// Context represents the current execution context of the VM. // scriptContext is a part of the Context that is shared between multiple Contexts,
type Context struct { // it's created when a new script is loaded into the VM while regular
// Instruction pointer. // CALL/CALLL/CALLA internal invocations reuse it.
ip int type scriptContext struct {
// The next instruction pointer.
nextip int
// The raw program script. // The raw program script.
prog []byte prog []byte
@ -33,12 +29,7 @@ type Context struct {
// Evaluation stack pointer. // Evaluation stack pointer.
estack *Stack estack *Stack
static *slot static slot
local slot
arguments slot
// Exception context stack.
tryStack Stack
// Script hash of the prog. // Script hash of the prog.
scriptHash util.Uint160 scriptHash util.Uint160
@ -49,17 +40,35 @@ type Context struct {
// Call flags this context was created with. // Call flags this context was created with.
callFlag callflag.CallFlag callFlag callflag.CallFlag
// retCount specifies the number of return values.
retCount int
// NEF represents a NEF file for the current contract. // NEF represents a NEF file for the current contract.
NEF *nef.File NEF *nef.File
// invTree is an invocation tree (or branch of it) for this context. // invTree is an invocation tree (or a branch of it) for this context.
invTree *invocations.Tree invTree *invocations.Tree
// onUnload is a callback that should be called after current context unloading // onUnload is a callback that should be called after current context unloading
// if no exception occurs. // if no exception occurs.
onUnload ContextUnloadCallback onUnload ContextUnloadCallback
} }
// Context represents the current execution context of the VM.
type Context struct {
// Instruction pointer.
ip int
// The next instruction pointer.
nextip int
sc *scriptContext
local slot
arguments slot
// Exception context stack.
tryStack Stack
// retCount specifies the number of return values.
retCount int
}
// ContextUnloadCallback is a callback method used on context unloading from istack. // ContextUnloadCallback is a callback method used on context unloading from istack.
type ContextUnloadCallback func(commit bool) error type ContextUnloadCallback func(commit bool) error
@ -74,7 +83,9 @@ func NewContext(b []byte) *Context {
// return value count and initial position in script. // return value count and initial position in script.
func NewContextWithParams(b []byte, rvcount int, pos int) *Context { func NewContextWithParams(b []byte, rvcount int, pos int) *Context {
return &Context{ return &Context{
sc: &scriptContext{
prog: b, prog: b,
},
retCount: rvcount, retCount: rvcount,
nextip: pos, nextip: pos,
} }
@ -82,7 +93,7 @@ func NewContextWithParams(b []byte, rvcount int, pos int) *Context {
// Estack returns the evaluation stack of c. // Estack returns the evaluation stack of c.
func (c *Context) Estack() *Stack { func (c *Context) Estack() *Stack {
return c.estack return c.sc.estack
} }
// NextIP returns the next instruction pointer. // NextIP returns the next instruction pointer.
@ -92,7 +103,7 @@ func (c *Context) NextIP() int {
// Jump unconditionally moves the next instruction pointer to the specified location. // Jump unconditionally moves the next instruction pointer to the specified location.
func (c *Context) Jump(pos int) { func (c *Context) Jump(pos int) {
if pos < 0 || pos >= len(c.prog) { if pos < 0 || pos >= len(c.sc.prog) {
panic("instruction offset is out of range") panic("instruction offset is out of range")
} }
c.nextip = pos c.nextip = pos
@ -105,11 +116,12 @@ func (c *Context) Next() (opcode.Opcode, []byte, error) {
var err error var err error
c.ip = c.nextip c.ip = c.nextip
if c.ip >= len(c.prog) { prog := c.sc.prog
if c.ip >= len(prog) {
return opcode.RET, nil, nil return opcode.RET, nil, nil
} }
var instrbyte = c.prog[c.ip] var instrbyte = prog[c.ip]
instr := opcode.Opcode(instrbyte) instr := opcode.Opcode(instrbyte)
if !opcode.IsValid(instr) { if !opcode.IsValid(instr) {
return instr, nil, fmt.Errorf("incorrect opcode %s", instr.String()) return instr, nil, fmt.Errorf("incorrect opcode %s", instr.String())
@ -119,24 +131,24 @@ func (c *Context) Next() (opcode.Opcode, []byte, error) {
var numtoread int var numtoread int
switch instr { switch instr {
case opcode.PUSHDATA1: case opcode.PUSHDATA1:
if c.nextip >= len(c.prog) { if c.nextip >= len(prog) {
err = errNoInstParam err = errNoInstParam
} else { } else {
numtoread = int(c.prog[c.nextip]) numtoread = int(prog[c.nextip])
c.nextip++ c.nextip++
} }
case opcode.PUSHDATA2: case opcode.PUSHDATA2:
if c.nextip+1 >= len(c.prog) { if c.nextip+1 >= len(prog) {
err = errNoInstParam err = errNoInstParam
} else { } else {
numtoread = int(binary.LittleEndian.Uint16(c.prog[c.nextip : c.nextip+2])) numtoread = int(binary.LittleEndian.Uint16(prog[c.nextip : c.nextip+2]))
c.nextip += 2 c.nextip += 2
} }
case opcode.PUSHDATA4: case opcode.PUSHDATA4:
if c.nextip+3 >= len(c.prog) { if c.nextip+3 >= len(prog) {
err = errNoInstParam err = errNoInstParam
} else { } else {
var n = binary.LittleEndian.Uint32(c.prog[c.nextip : c.nextip+4]) var n = binary.LittleEndian.Uint32(prog[c.nextip : c.nextip+4])
if n > stackitem.MaxSize { if n > stackitem.MaxSize {
return instr, nil, errors.New("parameter is too big") return instr, nil, errors.New("parameter is too big")
} }
@ -166,13 +178,13 @@ func (c *Context) Next() (opcode.Opcode, []byte, error) {
return instr, nil, nil return instr, nil, nil
} }
} }
if c.nextip+numtoread-1 >= len(c.prog) { if c.nextip+numtoread-1 >= len(prog) {
err = errNoInstParam err = errNoInstParam
} }
if err != nil { if err != nil {
return instr, nil, err return instr, nil, err
} }
parameter := c.prog[c.nextip : c.nextip+numtoread] parameter := prog[c.nextip : c.nextip+numtoread]
c.nextip += numtoread c.nextip += numtoread
return instr, parameter, nil return instr, parameter, nil
} }
@ -184,46 +196,44 @@ func (c *Context) IP() int {
// LenInstr returns the number of instructions loaded. // LenInstr returns the number of instructions loaded.
func (c *Context) LenInstr() int { func (c *Context) LenInstr() int {
return len(c.prog) return len(c.sc.prog)
} }
// CurrInstr returns the current instruction and opcode. // CurrInstr returns the current instruction and opcode.
func (c *Context) CurrInstr() (int, opcode.Opcode) { func (c *Context) CurrInstr() (int, opcode.Opcode) {
return c.ip, opcode.Opcode(c.prog[c.ip]) return c.ip, opcode.Opcode(c.sc.prog[c.ip])
} }
// NextInstr returns the next instruction and opcode. // NextInstr returns the next instruction and opcode.
func (c *Context) NextInstr() (int, opcode.Opcode) { func (c *Context) NextInstr() (int, opcode.Opcode) {
op := opcode.RET op := opcode.RET
if c.nextip < len(c.prog) { if c.nextip < len(c.sc.prog) {
op = opcode.Opcode(c.prog[c.nextip]) op = opcode.Opcode(c.sc.prog[c.nextip])
} }
return c.nextip, op return c.nextip, op
} }
// Copy returns an new exact copy of c.
func (c *Context) Copy() *Context {
ctx := new(Context)
*ctx = *c
return ctx
}
// GetCallFlags returns the calling flags which the context was created with. // GetCallFlags returns the calling flags which the context was created with.
func (c *Context) GetCallFlags() callflag.CallFlag { func (c *Context) GetCallFlags() callflag.CallFlag {
return c.callFlag return c.sc.callFlag
} }
// Program returns the loaded program. // Program returns the loaded program.
func (c *Context) Program() []byte { func (c *Context) Program() []byte {
return c.prog return c.sc.prog
} }
// ScriptHash returns a hash of the script in the current context. // ScriptHash returns a hash of the script in the current context.
func (c *Context) ScriptHash() util.Uint160 { func (c *Context) ScriptHash() util.Uint160 {
if c.scriptHash.Equals(util.Uint160{}) { if c.sc.scriptHash.Equals(util.Uint160{}) {
c.scriptHash = hash.Hash160(c.prog) c.sc.scriptHash = hash.Hash160(c.sc.prog)
} }
return c.scriptHash return c.sc.scriptHash
}
// GetNEF returns NEF structure used by this context if it's present.
func (c *Context) GetNEF() *nef.File {
return c.sc.NEF
} }
// Value implements the stackitem.Item interface. // Value implements the stackitem.Item interface.
@ -263,7 +273,7 @@ func (c *Context) Equals(s stackitem.Item) bool {
} }
func (c *Context) atBreakPoint() bool { func (c *Context) atBreakPoint() bool {
for _, n := range c.breakPoints { for _, n := range c.sc.breakPoints {
if n == c.nextip { if n == c.nextip {
return true return true
} }
@ -277,12 +287,12 @@ func (c *Context) String() string {
// IsDeployed returns whether this context contains a deployed contract. // IsDeployed returns whether this context contains a deployed contract.
func (c *Context) IsDeployed() bool { func (c *Context) IsDeployed() bool {
return c.NEF != nil return c.sc.NEF != nil
} }
// DumpStaticSlot returns json formatted representation of the given slot. // DumpStaticSlot returns json formatted representation of the given slot.
func (c *Context) DumpStaticSlot() string { func (c *Context) DumpStaticSlot() string {
return dumpSlot(c.static) return dumpSlot(&c.sc.static)
} }
// DumpLocalSlot returns json formatted representation of the given slot. // DumpLocalSlot returns json formatted representation of the given slot.

View file

@ -36,8 +36,9 @@ func defaultSyscallHandler(v *VM, id uint32) error {
return errors.New("syscall not found") return errors.New("syscall not found")
} }
d := defaultVMInterops[n] d := defaultVMInterops[n]
if !v.Context().callFlag.Has(d.RequiredFlags) { ctxFlag := v.Context().sc.callFlag
return fmt.Errorf("missing call flags: %05b vs %05b", v.Context().callFlag, d.RequiredFlags) if !ctxFlag.Has(d.RequiredFlags) {
return fmt.Errorf("missing call flags: %05b vs %05b", ctxFlag, d.RequiredFlags)
} }
return d.Func(v) return d.Func(v)
} }

View file

@ -115,7 +115,7 @@ func testSyscallHandler(v *VM, id uint32) error {
case 0x77777777: case 0x77777777:
v.Estack().PushVal(stackitem.NewInterop(new(int))) v.Estack().PushVal(stackitem.NewInterop(new(int)))
case 0x66666666: case 0x66666666:
if !v.Context().callFlag.Has(callflag.ReadOnly) { if !v.Context().sc.callFlag.Has(callflag.ReadOnly) {
return errors.New("invalid call flags") return errors.New("invalid call flags")
} }
v.Estack().PushVal(stackitem.NewInterop(new(int))) v.Estack().PushVal(stackitem.NewInterop(new(int)))
@ -167,14 +167,14 @@ func testFile(t *testing.T, filename string) {
if len(result.InvocationStack) > 0 { if len(result.InvocationStack) > 0 {
for i, s := range result.InvocationStack { for i, s := range result.InvocationStack {
ctx := vm.istack.Peek(i).Value().(*Context) ctx := vm.istack.Peek(i).Value().(*Context)
if ctx.nextip < len(ctx.prog) { if ctx.nextip < len(ctx.sc.prog) {
require.Equal(t, s.InstructionPointer, ctx.nextip) require.Equal(t, s.InstructionPointer, ctx.nextip)
op, err := opcode.FromString(s.Instruction) op, err := opcode.FromString(s.Instruction)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, op, opcode.Opcode(ctx.prog[ctx.nextip])) require.Equal(t, op, opcode.Opcode(ctx.sc.prog[ctx.nextip]))
} }
compareStacks(t, s.EStack, vm.estack) compareStacks(t, s.EStack, vm.estack)
compareSlots(t, s.StaticFields, ctx.static) compareSlots(t, s.StaticFields, ctx.sc.static)
} }
} }
@ -240,8 +240,8 @@ func compareStacks(t *testing.T, expected []vmUTStackItem, actual *Stack) {
compareItemArrays(t, expected, actual.Len(), func(i int) stackitem.Item { return actual.Peek(i).Item() }) compareItemArrays(t, expected, actual.Len(), func(i int) stackitem.Item { return actual.Peek(i).Item() })
} }
func compareSlots(t *testing.T, expected []vmUTStackItem, actual *slot) { func compareSlots(t *testing.T, expected []vmUTStackItem, actual slot) {
if (actual == nil || *actual == nil) && len(expected) == 0 { if actual == nil && len(expected) == 0 {
return return
} }
require.NotNil(t, actual) require.NotNil(t, actual)

View file

@ -56,7 +56,7 @@ func opParamSlotsPushVM(op opcode.Opcode, param []byte, sslot int, slotloc int,
return nil return nil
} }
if sslot != 0 { if sslot != 0 {
v.Context().static.init(sslot, &v.refs) v.Context().sc.static.init(sslot, &v.refs)
} }
if slotloc != 0 && slotarg != 0 { if slotloc != 0 && slotarg != 0 {
v.Context().local.init(slotloc, &v.refs) v.Context().local.init(slotloc, &v.refs)

View file

@ -171,9 +171,7 @@ func (v *VM) PrintOps(out io.Writer) {
w := tabwriter.NewWriter(out, 0, 0, 4, ' ', 0) w := tabwriter.NewWriter(out, 0, 0, 4, ' ', 0)
fmt.Fprintln(w, "INDEX\tOPCODE\tPARAMETER") fmt.Fprintln(w, "INDEX\tOPCODE\tPARAMETER")
realctx := v.Context() realctx := v.Context()
ctx := realctx.Copy() ctx := &Context{sc: realctx.sc}
ctx.ip = 0
ctx.nextip = 0
for { for {
cursor := "" cursor := ""
instr, parameter, err := ctx.Next() instr, parameter, err := ctx.Next()
@ -228,7 +226,7 @@ func (v *VM) PrintOps(out io.Writer) {
} }
fmt.Fprintf(w, "%d\t%s\t%s%s\n", ctx.ip, instr, desc, cursor) fmt.Fprintf(w, "%d\t%s\t%s%s\n", ctx.ip, instr, desc, cursor)
if ctx.nextip >= len(ctx.prog) { if ctx.nextip >= len(ctx.sc.prog) {
break break
} }
} }
@ -246,7 +244,7 @@ func getOffsetDesc(ctx *Context, parameter []byte) string {
// AddBreakPoint adds a breakpoint to the current context. // AddBreakPoint adds a breakpoint to the current context.
func (v *VM) AddBreakPoint(n int) { func (v *VM) AddBreakPoint(n int) {
ctx := v.Context() ctx := v.Context()
ctx.breakPoints = append(ctx.breakPoints, n) ctx.sc.breakPoints = append(ctx.sc.breakPoints, n)
} }
// AddBreakPointRel adds a breakpoint relative to the current // AddBreakPointRel adds a breakpoint relative to the current
@ -337,31 +335,28 @@ func (v *VM) LoadNEFMethod(exe *nef.File, caller util.Uint160, hash util.Uint160
// It should be used for calling from native contracts. // It should be used for calling from native contracts.
func (v *VM) loadScriptWithCallingHash(b []byte, exe *nef.File, caller util.Uint160, func (v *VM) loadScriptWithCallingHash(b []byte, exe *nef.File, caller util.Uint160,
hash util.Uint160, f callflag.CallFlag, rvcount int, offset int, onContextUnload ContextUnloadCallback) { hash util.Uint160, f callflag.CallFlag, rvcount int, offset int, onContextUnload ContextUnloadCallback) {
var sl slot
v.checkInvocationStackSize() v.checkInvocationStackSize()
ctx := NewContextWithParams(b, rvcount, offset) ctx := NewContextWithParams(b, rvcount, offset)
if rvcount != -1 || v.estack.Len() != 0 { if rvcount != -1 || v.estack.Len() != 0 {
v.estack = subStack(v.estack) v.estack = subStack(v.estack)
} }
ctx.estack = v.estack ctx.sc.estack = v.estack
initStack(&ctx.tryStack, "exception", nil) initStack(&ctx.tryStack, "exception", nil)
ctx.callFlag = f ctx.sc.callFlag = f
ctx.static = &sl ctx.sc.scriptHash = hash
ctx.scriptHash = hash ctx.sc.callingScriptHash = caller
ctx.callingScriptHash = caller ctx.sc.NEF = exe
ctx.NEF = exe
if v.invTree != nil { if v.invTree != nil {
curTree := v.invTree curTree := v.invTree
parent := v.Context() parent := v.Context()
if parent != nil { if parent != nil {
curTree = parent.invTree curTree = parent.sc.invTree
} }
newTree := &invocations.Tree{Current: ctx.ScriptHash()} newTree := &invocations.Tree{Current: ctx.ScriptHash()}
curTree.Calls = append(curTree.Calls, newTree) curTree.Calls = append(curTree.Calls, newTree)
ctx.invTree = newTree ctx.sc.invTree = newTree
} }
ctx.onUnload = onContextUnload ctx.sc.onUnload = onContextUnload
v.istack.PushItem(ctx) v.istack.PushItem(ctx)
} }
@ -481,7 +476,7 @@ func (v *VM) StepInto() error {
return nil return nil
} }
if ctx != nil && ctx.prog != nil { if ctx != nil && ctx.sc.prog != nil {
op, param, err := ctx.Next() op, param, err := ctx.Next()
if err != nil { if err != nil {
v.state = vmstate.Fault v.state = vmstate.Fault
@ -584,7 +579,7 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro
} }
}() }()
if v.getPrice != nil && ctx.ip < len(ctx.prog) { if v.getPrice != nil && ctx.ip < len(ctx.sc.prog) {
v.gasConsumed += v.getPrice(op, parameter) v.gasConsumed += v.getPrice(op, parameter)
if v.GasLimit >= 0 && v.gasConsumed > v.GasLimit { if v.GasLimit >= 0 && v.gasConsumed > v.GasLimit {
panic("gas limit is exceeded") panic("gas limit is exceeded")
@ -610,7 +605,7 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro
case opcode.PUSHA: case opcode.PUSHA:
n := getJumpOffset(ctx, parameter) n := getJumpOffset(ctx, parameter)
ptr := stackitem.NewPointerWithHash(n, ctx.prog, ctx.ScriptHash()) ptr := stackitem.NewPointerWithHash(n, ctx.sc.prog, ctx.ScriptHash())
v.estack.PushItem(ptr) v.estack.PushItem(ptr)
case opcode.PUSHNULL: case opcode.PUSHNULL:
@ -637,7 +632,7 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro
if parameter[0] == 0 { if parameter[0] == 0 {
panic("zero argument") panic("zero argument")
} }
ctx.static.init(int(parameter[0]), &v.refs) ctx.sc.static.init(int(parameter[0]), &v.refs)
case opcode.INITSLOT: case opcode.INITSLOT:
if ctx.local != nil || ctx.arguments != nil { if ctx.local != nil || ctx.arguments != nil {
@ -658,20 +653,20 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro
} }
case opcode.LDSFLD0, opcode.LDSFLD1, opcode.LDSFLD2, opcode.LDSFLD3, opcode.LDSFLD4, opcode.LDSFLD5, opcode.LDSFLD6: case opcode.LDSFLD0, opcode.LDSFLD1, opcode.LDSFLD2, opcode.LDSFLD3, opcode.LDSFLD4, opcode.LDSFLD5, opcode.LDSFLD6:
item := ctx.static.Get(int(op - opcode.LDSFLD0)) item := ctx.sc.static.Get(int(op - opcode.LDSFLD0))
v.estack.PushItem(item) v.estack.PushItem(item)
case opcode.LDSFLD: case opcode.LDSFLD:
item := ctx.static.Get(int(parameter[0])) item := ctx.sc.static.Get(int(parameter[0]))
v.estack.PushItem(item) v.estack.PushItem(item)
case opcode.STSFLD0, opcode.STSFLD1, opcode.STSFLD2, opcode.STSFLD3, opcode.STSFLD4, opcode.STSFLD5, opcode.STSFLD6: case opcode.STSFLD0, opcode.STSFLD1, opcode.STSFLD2, opcode.STSFLD3, opcode.STSFLD4, opcode.STSFLD5, opcode.STSFLD6:
item := v.estack.Pop().Item() item := v.estack.Pop().Item()
ctx.static.Set(int(op-opcode.STSFLD0), item, &v.refs) ctx.sc.static.Set(int(op-opcode.STSFLD0), item, &v.refs)
case opcode.STSFLD: case opcode.STSFLD:
item := v.estack.Pop().Item() item := v.estack.Pop().Item()
ctx.static.Set(int(parameter[0]), item, &v.refs) ctx.sc.static.Set(int(parameter[0]), item, &v.refs)
case opcode.LDLOC0, opcode.LDLOC1, opcode.LDLOC2, opcode.LDLOC3, opcode.LDLOC4, opcode.LDLOC5, opcode.LDLOC6: case opcode.LDLOC0, opcode.LDLOC1, opcode.LDLOC2, opcode.LDLOC3, opcode.LDLOC4, opcode.LDLOC5, opcode.LDLOC6:
item := ctx.local.Get(int(op - opcode.LDLOC0)) item := ctx.local.Get(int(op - opcode.LDLOC0))
@ -1475,7 +1470,7 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro
break break
} }
newEstack := v.Context().estack newEstack := v.Context().sc.estack
if oldEstack != newEstack { if oldEstack != newEstack {
if oldCtx.retCount >= 0 && oldEstack.Len() != oldCtx.retCount { if oldCtx.retCount >= 0 && oldEstack.Len() != oldCtx.retCount {
panic(fmt.Errorf("invalid return values count: expected %d, got %d", panic(fmt.Errorf("invalid return values count: expected %d, got %d",
@ -1631,11 +1626,12 @@ func (v *VM) unloadContext(ctx *Context) {
ctx.arguments.ClearRefs(&v.refs) ctx.arguments.ClearRefs(&v.refs)
} }
currCtx := v.Context() currCtx := v.Context()
if ctx.static != nil && (currCtx == nil || ctx.static != currCtx.static) { if currCtx == nil || ctx.sc != currCtx.sc {
ctx.static.ClearRefs(&v.refs) if ctx.sc.static != nil {
ctx.sc.static.ClearRefs(&v.refs)
} }
if ctx.onUnload != nil { if ctx.sc.onUnload != nil {
err := ctx.onUnload(v.uncaughtException == nil) err := ctx.sc.onUnload(v.uncaughtException == nil)
if err != nil { if err != nil {
errMessage := fmt.Sprintf("context unload callback failed: %s", err) errMessage := fmt.Sprintf("context unload callback failed: %s", err)
if v.uncaughtException != nil { if v.uncaughtException != nil {
@ -1644,6 +1640,7 @@ func (v *VM) unloadContext(ctx *Context) {
panic(errors.New(errMessage)) panic(errors.New(errMessage))
} }
} }
}
} }
// getTryParams splits TRY(L) instruction parameter into offsets for catch and finally blocks. // getTryParams splits TRY(L) instruction parameter into offsets for catch and finally blocks.
@ -1691,17 +1688,13 @@ func (v *VM) Call(offset int) {
// package. // package.
func (v *VM) call(ctx *Context, offset int) { func (v *VM) call(ctx *Context, offset int) {
v.checkInvocationStackSize() v.checkInvocationStackSize()
newCtx := ctx.Copy() newCtx := &Context{
newCtx.retCount = -1 sc: ctx.sc,
newCtx.local = nil retCount: -1,
newCtx.arguments = nil tryStack: ctx.tryStack,
// If memory for `elems` is reused, we can end up }
// with an incorrect exception context state in the caller. // New context -> new exception handlers.
newCtx.tryStack.elems = nil newCtx.tryStack.elems = ctx.tryStack.elems[len(ctx.tryStack.elems):]
initStack(&newCtx.tryStack, "exception", nil)
newCtx.NEF = ctx.NEF
// Do not clone unloading callback, new context does not require any actions to perform on unloading.
newCtx.onUnload = nil
v.istack.PushItem(newCtx) v.istack.PushItem(newCtx)
newCtx.Jump(offset) newCtx.Jump(offset)
} }
@ -1732,7 +1725,7 @@ func calcJumpOffset(ctx *Context, parameter []byte) (int, int, error) {
return 0, 0, fmt.Errorf("invalid %s parameter length: %d", curr, l) return 0, 0, fmt.Errorf("invalid %s parameter length: %d", curr, l)
} }
offset := ctx.ip + int(rOffset) offset := ctx.ip + int(rOffset)
if offset < 0 || offset > len(ctx.prog) { if offset < 0 || offset > len(ctx.sc.prog) {
return 0, 0, fmt.Errorf("invalid offset %d ip at %d", offset, ctx.ip) return 0, 0, fmt.Errorf("invalid offset %d ip at %d", offset, ctx.ip)
} }
@ -1955,7 +1948,7 @@ func bytesToPublicKey(b []byte, curve elliptic.Curve) *keys.PublicKey {
// GetCallingScriptHash implements the ScriptHashGetter interface. // GetCallingScriptHash implements the ScriptHashGetter interface.
func (v *VM) GetCallingScriptHash() util.Uint160 { func (v *VM) GetCallingScriptHash() util.Uint160 {
return v.Context().callingScriptHash return v.Context().sc.callingScriptHash
} }
// GetEntryScriptHash implements the ScriptHashGetter interface. // GetEntryScriptHash implements the ScriptHashGetter interface.