diff --git a/pkg/core/interop/contract/call.go b/pkg/core/interop/contract/call.go index 3f47fd5ca..aa5989c6d 100644 --- a/pkg/core/interop/contract/call.go +++ b/pkg/core/interop/contract/call.go @@ -14,6 +14,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/smartcontract/callflag" "github.com/nspcc-dev/neo-go/pkg/smartcontract/manifest" "github.com/nspcc-dev/neo-go/pkg/util" + "github.com/nspcc-dev/neo-go/pkg/vm" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" ) @@ -27,7 +28,7 @@ func LoadToken(ic *interop.Context, id int32) error { if !ctx.GetCallFlags().Has(callflag.ReadStates | callflag.AllowCall) { return errors.New("invalid call flags") } - tok := ctx.NEF.Tokens[id] + tok := ctx.GetNEF().Tokens[id] if int(tok.ParamCount) > ctx.Estack().Len() { return errors.New("stack is too small") } @@ -67,11 +68,11 @@ func Call(ic *interop.Context) error { return fmt.Errorf("method not found: %s/%d", method, len(args)) } hasReturn := md.ReturnType != smartcontract.VoidType - return callInternal(ic, cs, method, fs, hasReturn, args, !hasReturn) + return callInternal(ic, cs, method, fs, hasReturn, args, true) } func callInternal(ic *interop.Context, cs *state.Contract, name string, f callflag.CallFlag, - hasReturn bool, args []stackitem.Item, pushNullOnUnloading bool) error { + hasReturn bool, args []stackitem.Item, isDynamic bool) error { md := cs.Manifest.ABI.GetMethod(name, len(args)) if md.Safe { f &^= (callflag.WriteStates | callflag.AllowNotify) @@ -83,12 +84,12 @@ func callInternal(ic *interop.Context, cs *state.Contract, name string, f callfl } } } - return callExFromNative(ic, ic.VM.GetCurrentScriptHash(), cs, name, args, f, hasReturn, pushNullOnUnloading, false) + return callExFromNative(ic, ic.VM.GetCurrentScriptHash(), cs, name, args, f, hasReturn, isDynamic, false) } // callExFromNative calls a contract with flags using the provided calling hash. func callExFromNative(ic *interop.Context, caller util.Uint160, cs *state.Contract, - name string, args []stackitem.Item, f callflag.CallFlag, hasReturn bool, pushNullOnUnloading bool, callFromNative bool) error { + name string, args []stackitem.Item, f callflag.CallFlag, hasReturn bool, isDynamic bool, callFromNative bool) error { for _, nc := range ic.Natives { if nc.Metadata().Name == nativenames.Policy { var pch = nc.(policyChecker) @@ -123,7 +124,7 @@ func callExFromNative(ic *interop.Context, caller util.Uint160, cs *state.Contra if wrapped { ic.DAO = ic.DAO.GetPrivate() } - onUnload := func(commit bool) error { + onUnload := func(ctx *vm.Context, commit bool) error { if wrapped { if commit { _, err := ic.DAO.Persist() @@ -135,8 +136,13 @@ func callExFromNative(ic *interop.Context, caller util.Uint160, cs *state.Contra } ic.DAO = baseDAO } - if pushNullOnUnloading && commit { - ic.VM.Context().Estack().PushItem(stackitem.Null{}) // Must use current context stack. + if isDynamic && commit { + eLen := ctx.Estack().Len() + if eLen == 0 && ctx.NumOfReturnVals() == 0 { // No return value and none expected. + ic.VM.Context().Estack().PushItem(stackitem.Null{}) // Must use current context stack. + } else if eLen > 1 { // 1 or -1 (all) retrun values expected, but only one can be returned. + return errors.New("multiple return values in a cross-contract call") + } // All other rvcount/stack length mismatches are checked by the VM. } if callFromNative && !commit { return fmt.Errorf("unhandled exception") diff --git a/pkg/core/interop/runtime/engine.go b/pkg/core/interop/runtime/engine.go index 1223d878a..73654db19 100644 --- a/pkg/core/interop/runtime/engine.go +++ b/pkg/core/interop/runtime/engine.go @@ -73,6 +73,10 @@ func Notify(ic *interop.Context) error { if len(name) > MaxEventNameLen { return fmt.Errorf("event name must be less than %d", MaxEventNameLen) } + if !ic.VM.Context().IsDeployed() { + return errors.New("notifications are not allowed in dynamic scripts") + } + // But it has to be serializable, otherwise we either have some broken // (recursive) structure inside or an interop item that can't be used // outside of the interop subsystem anyway. diff --git a/pkg/core/interop/runtime/engine_test.go b/pkg/core/interop/runtime/engine_test.go index fc3b99b85..4ddf628ae 100644 --- a/pkg/core/interop/runtime/engine_test.go +++ b/pkg/core/interop/runtime/engine_test.go @@ -12,6 +12,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/core/interop" "github.com/nspcc-dev/neo-go/pkg/crypto/hash" "github.com/nspcc-dev/neo-go/pkg/smartcontract/callflag" + "github.com/nspcc-dev/neo-go/pkg/smartcontract/nef" "github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger" "github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/vm" @@ -133,9 +134,12 @@ func TestLog(t *testing.T) { func TestNotify(t *testing.T) { h := random.Uint160() + caller := random.Uint160() + exe, err := nef.NewFile([]byte{1}) + require.NoError(t, err) newIC := func(name string, args interface{}) *interop.Context { ic := &interop.Context{VM: vm.New(), DAO: &dao.Simple{}} - ic.VM.LoadScriptWithHash([]byte{1}, h, callflag.NoneFlag) + ic.VM.LoadNEFMethod(exe, caller, h, callflag.NoneFlag, true, 0, -1, nil) ic.VM.Estack().PushVal(args) ic.VM.Estack().PushVal(name) return ic @@ -144,6 +148,13 @@ func TestNotify(t *testing.T) { ic := newIC(string(make([]byte, MaxEventNameLen+1)), stackitem.NewArray([]stackitem.Item{stackitem.Null{}})) require.Error(t, Notify(ic)) }) + t.Run("dynamic script", func(t *testing.T) { + ic := &interop.Context{VM: vm.New(), DAO: &dao.Simple{}} + ic.VM.LoadScriptWithHash([]byte{1}, h, callflag.NoneFlag) + ic.VM.Estack().PushVal(stackitem.NewArray([]stackitem.Item{stackitem.Make(42)})) + ic.VM.Estack().PushVal("event") + require.Error(t, Notify(ic)) + }) t.Run("recursive struct", func(t *testing.T) { arr := stackitem.NewArray([]stackitem.Item{stackitem.Null{}}) arr.Append(arr) diff --git a/pkg/core/interop/runtime/witness.go b/pkg/core/interop/runtime/witness.go index d6a01e652..09c4335d4 100644 --- a/pkg/core/interop/runtime/witness.go +++ b/pkg/core/interop/runtime/witness.go @@ -41,6 +41,10 @@ func getContractGroups(v *vm.VM, ic *interop.Context, h util.Uint160) (manifest. return manifest.Groups(cs.Manifest.Groups), nil } +func (sc scopeContext) IsCalledByEntry() bool { + return sc.VM.Context().IsCalledByEntry() +} + func (sc scopeContext) checkScriptGroups(h util.Uint160, k *keys.PublicKey) (bool, error) { groups, err := getContractGroups(sc.VM, sc.ic, h) if err != nil { @@ -69,9 +73,7 @@ func checkScope(ic *interop.Context, hash util.Uint160) (bool, error) { return true, nil } if c.Scopes&transaction.CalledByEntry != 0 { - callingScriptHash := ic.VM.GetCallingScriptHash() - entryScriptHash := ic.VM.GetEntryScriptHash() - if callingScriptHash.Equals(util.Uint160{}) || callingScriptHash == entryScriptHash { + if ic.VM.Context().IsCalledByEntry() { return true, nil } } diff --git a/pkg/core/transaction/witness_condition.go b/pkg/core/transaction/witness_condition.go index 71fcc1d33..65a73b09e 100644 --- a/pkg/core/transaction/witness_condition.go +++ b/pkg/core/transaction/witness_condition.go @@ -63,9 +63,9 @@ type WitnessCondition interface { type MatchContext interface { GetCallingScriptHash() util.Uint160 GetCurrentScriptHash() util.Uint160 - GetEntryScriptHash() util.Uint160 CallingScriptHasGroup(*keys.PublicKey) (bool, error) CurrentScriptHasGroup(*keys.PublicKey) (bool, error) + IsCalledByEntry() bool } type ( @@ -394,8 +394,7 @@ func (c ConditionCalledByEntry) Type() WitnessConditionType { // Match implements the WitnessCondition interface checking whether this condition // matches given context. func (c ConditionCalledByEntry) Match(ctx MatchContext) (bool, error) { - entry := ctx.GetEntryScriptHash() - return entry.Equals(ctx.GetCallingScriptHash()) || entry.Equals(ctx.GetCurrentScriptHash()), nil + return ctx.IsCalledByEntry(), nil } // EncodeBinary implements the WitnessCondition interface allowing to serialize condition. diff --git a/pkg/core/transaction/witness_condition_test.go b/pkg/core/transaction/witness_condition_test.go index e8ac6d4fc..2a4f56244 100644 --- a/pkg/core/transaction/witness_condition_test.go +++ b/pkg/core/transaction/witness_condition_test.go @@ -170,6 +170,9 @@ func (t *TestMC) GetCurrentScriptHash() util.Uint160 { func (t *TestMC) GetEntryScriptHash() util.Uint160 { return t.entry } +func (t *TestMC) IsCalledByEntry() bool { + return t.entry.Equals(t.calling) || t.calling.Equals(util.Uint160{}) +} func (t *TestMC) CallingScriptHasGroup(k *keys.PublicKey) (bool, error) { res, err := t.CurrentScriptHasGroup(k) return !res, err // To differentiate from current we invert the logic value. diff --git a/pkg/core/transaction/witness_scope.go b/pkg/core/transaction/witness_scope.go index 5daf15eaa..aa67eb33a 100644 --- a/pkg/core/transaction/witness_scope.go +++ b/pkg/core/transaction/witness_scope.go @@ -13,7 +13,7 @@ type WitnessScope byte const ( // None specifies that no contract was witnessed. Only sign the transaction. None WitnessScope = 0 - // CalledByEntry means that this condition must hold: EntryScriptHash == CallingScriptHash. + // CalledByEntry witness is only valid in entry script and ones directly called by it. // No params is needed, as the witness/permission/signature given on first invocation will // automatically expire if entering deeper internal invokes. This can be default safe // choice for native NEO/GAS (previously used on Neo 2 as "attach" mode). diff --git a/pkg/vm/context.go b/pkg/vm/context.go index 955554482..6c086249f 100644 --- a/pkg/vm/context.go +++ b/pkg/vm/context.go @@ -16,14 +16,10 @@ import ( "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" ) -// Context represents the current execution context of the VM. -type Context struct { - // Instruction pointer. - ip int - - // The next instruction pointer. - nextip int - +// scriptContext is a part of the Context that is shared between multiple Contexts, +// it's created when a new script is loaded into the VM while regular +// CALL/CALLL/CALLA internal invocations reuse it. +type scriptContext struct { // The raw program script. prog []byte @@ -33,12 +29,7 @@ type Context struct { // Evaluation stack pointer. estack *Stack - static *slot - local slot - arguments slot - - // Exception context stack. - tryStack Stack + static slot // Script hash of the prog. scriptHash util.Uint160 @@ -46,22 +37,43 @@ type Context struct { // Caller's contract script hash. callingScriptHash util.Uint160 + // Caller's scriptContext, if not entry. + callingContext *scriptContext + // Call flags this context was created with. callFlag callflag.CallFlag - // retCount specifies the number of return values. - retCount int // NEF represents a NEF file for the current contract. NEF *nef.File - // invTree is an invocation tree (or branch of it) for this context. + // invTree is an invocation tree (or a branch of it) for this context. invTree *invocations.Tree // onUnload is a callback that should be called after current context unloading // if no exception occurs. onUnload ContextUnloadCallback } +// Context represents the current execution context of the VM. +type Context struct { + // Instruction pointer. + ip int + + // The next instruction pointer. + nextip int + + sc *scriptContext + + local slot + arguments slot + + // Exception context stack. + tryStack Stack + + // retCount specifies the number of return values. + retCount int +} + // ContextUnloadCallback is a callback method used on context unloading from istack. -type ContextUnloadCallback func(commit bool) error +type ContextUnloadCallback func(ctx *Context, commit bool) error var errNoInstParam = errors.New("failed to read instruction parameter") @@ -74,7 +86,9 @@ func NewContext(b []byte) *Context { // return value count and initial position in script. func NewContextWithParams(b []byte, rvcount int, pos int) *Context { return &Context{ - prog: b, + sc: &scriptContext{ + prog: b, + }, retCount: rvcount, nextip: pos, } @@ -82,7 +96,7 @@ func NewContextWithParams(b []byte, rvcount int, pos int) *Context { // Estack returns the evaluation stack of c. func (c *Context) Estack() *Stack { - return c.estack + return c.sc.estack } // NextIP returns the next instruction pointer. @@ -92,7 +106,7 @@ func (c *Context) NextIP() int { // Jump unconditionally moves the next instruction pointer to the specified location. func (c *Context) Jump(pos int) { - if pos < 0 || pos >= len(c.prog) { + if pos < 0 || pos >= len(c.sc.prog) { panic("instruction offset is out of range") } c.nextip = pos @@ -105,11 +119,12 @@ func (c *Context) Next() (opcode.Opcode, []byte, error) { var err error c.ip = c.nextip - if c.ip >= len(c.prog) { + prog := c.sc.prog + if c.ip >= len(prog) { return opcode.RET, nil, nil } - var instrbyte = c.prog[c.ip] + var instrbyte = prog[c.ip] instr := opcode.Opcode(instrbyte) if !opcode.IsValid(instr) { return instr, nil, fmt.Errorf("incorrect opcode %s", instr.String()) @@ -119,24 +134,24 @@ func (c *Context) Next() (opcode.Opcode, []byte, error) { var numtoread int switch instr { case opcode.PUSHDATA1: - if c.nextip >= len(c.prog) { + if c.nextip >= len(prog) { err = errNoInstParam } else { - numtoread = int(c.prog[c.nextip]) + numtoread = int(prog[c.nextip]) c.nextip++ } case opcode.PUSHDATA2: - if c.nextip+1 >= len(c.prog) { + if c.nextip+1 >= len(prog) { err = errNoInstParam } else { - numtoread = int(binary.LittleEndian.Uint16(c.prog[c.nextip : c.nextip+2])) + numtoread = int(binary.LittleEndian.Uint16(prog[c.nextip : c.nextip+2])) c.nextip += 2 } case opcode.PUSHDATA4: - if c.nextip+3 >= len(c.prog) { + if c.nextip+3 >= len(prog) { err = errNoInstParam } else { - var n = binary.LittleEndian.Uint32(c.prog[c.nextip : c.nextip+4]) + var n = binary.LittleEndian.Uint32(prog[c.nextip : c.nextip+4]) if n > stackitem.MaxSize { return instr, nil, errors.New("parameter is too big") } @@ -166,13 +181,13 @@ func (c *Context) Next() (opcode.Opcode, []byte, error) { return instr, nil, nil } } - if c.nextip+numtoread-1 >= len(c.prog) { + if c.nextip+numtoread-1 >= len(prog) { err = errNoInstParam } if err != nil { return instr, nil, err } - parameter := c.prog[c.nextip : c.nextip+numtoread] + parameter := prog[c.nextip : c.nextip+numtoread] c.nextip += numtoread return instr, parameter, nil } @@ -184,46 +199,49 @@ func (c *Context) IP() int { // LenInstr returns the number of instructions loaded. func (c *Context) LenInstr() int { - return len(c.prog) + return len(c.sc.prog) } // CurrInstr returns the current instruction and opcode. func (c *Context) CurrInstr() (int, opcode.Opcode) { - return c.ip, opcode.Opcode(c.prog[c.ip]) + return c.ip, opcode.Opcode(c.sc.prog[c.ip]) } // NextInstr returns the next instruction and opcode. func (c *Context) NextInstr() (int, opcode.Opcode) { op := opcode.RET - if c.nextip < len(c.prog) { - op = opcode.Opcode(c.prog[c.nextip]) + if c.nextip < len(c.sc.prog) { + op = opcode.Opcode(c.sc.prog[c.nextip]) } return c.nextip, op } -// Copy returns an new exact copy of c. -func (c *Context) Copy() *Context { - ctx := new(Context) - *ctx = *c - return ctx -} - // GetCallFlags returns the calling flags which the context was created with. func (c *Context) GetCallFlags() callflag.CallFlag { - return c.callFlag + return c.sc.callFlag } // Program returns the loaded program. func (c *Context) Program() []byte { - return c.prog + return c.sc.prog } // ScriptHash returns a hash of the script in the current context. func (c *Context) ScriptHash() util.Uint160 { - if c.scriptHash.Equals(util.Uint160{}) { - c.scriptHash = hash.Hash160(c.prog) + if c.sc.scriptHash.Equals(util.Uint160{}) { + c.sc.scriptHash = hash.Hash160(c.sc.prog) } - return c.scriptHash + return c.sc.scriptHash +} + +// GetNEF returns NEF structure used by this context if it's present. +func (c *Context) GetNEF() *nef.File { + return c.sc.NEF +} + +// NumOfReturnVals returns the number of return values expected from this context. +func (c *Context) NumOfReturnVals() int { + return c.retCount } // Value implements the stackitem.Item interface. @@ -263,7 +281,7 @@ func (c *Context) Equals(s stackitem.Item) bool { } func (c *Context) atBreakPoint() bool { - for _, n := range c.breakPoints { + for _, n := range c.sc.breakPoints { if n == c.nextip { return true } @@ -277,12 +295,12 @@ func (c *Context) String() string { // IsDeployed returns whether this context contains a deployed contract. func (c *Context) IsDeployed() bool { - return c.NEF != nil + return c.sc.NEF != nil } // DumpStaticSlot returns json formatted representation of the given slot. func (c *Context) DumpStaticSlot() string { - return dumpSlot(c.static) + return dumpSlot(&c.sc.static) } // DumpLocalSlot returns json formatted representation of the given slot. @@ -316,6 +334,12 @@ func (v *VM) getContextScriptHash(n int) util.Uint160 { return ctx.ScriptHash() } +// IsCalledByEntry checks parent script contexts and return true if the current one +// is an entry script (the first loaded into the VM) or one called by it. +func (c *Context) IsCalledByEntry() bool { + return c.sc.callingContext == nil || c.sc.callingContext.callingContext == nil +} + // PushContextScriptHash pushes the script hash of the // invocation stack element number n to the evaluation stack. func (v *VM) PushContextScriptHash(n int) error { diff --git a/pkg/vm/interop.go b/pkg/vm/interop.go index 4885f255e..27156b31f 100644 --- a/pkg/vm/interop.go +++ b/pkg/vm/interop.go @@ -36,8 +36,9 @@ func defaultSyscallHandler(v *VM, id uint32) error { return errors.New("syscall not found") } d := defaultVMInterops[n] - if !v.Context().callFlag.Has(d.RequiredFlags) { - return fmt.Errorf("missing call flags: %05b vs %05b", v.Context().callFlag, d.RequiredFlags) + ctxFlag := v.Context().sc.callFlag + if !ctxFlag.Has(d.RequiredFlags) { + return fmt.Errorf("missing call flags: %05b vs %05b", ctxFlag, d.RequiredFlags) } return d.Func(v) } diff --git a/pkg/vm/json_test.go b/pkg/vm/json_test.go index 86a68540d..8256a6840 100644 --- a/pkg/vm/json_test.go +++ b/pkg/vm/json_test.go @@ -115,7 +115,7 @@ func testSyscallHandler(v *VM, id uint32) error { case 0x77777777: v.Estack().PushVal(stackitem.NewInterop(new(int))) case 0x66666666: - if !v.Context().callFlag.Has(callflag.ReadOnly) { + if !v.Context().sc.callFlag.Has(callflag.ReadOnly) { return errors.New("invalid call flags") } v.Estack().PushVal(stackitem.NewInterop(new(int))) @@ -167,14 +167,14 @@ func testFile(t *testing.T, filename string) { if len(result.InvocationStack) > 0 { for i, s := range result.InvocationStack { ctx := vm.istack.Peek(i).Value().(*Context) - if ctx.nextip < len(ctx.prog) { + if ctx.nextip < len(ctx.sc.prog) { require.Equal(t, s.InstructionPointer, ctx.nextip) op, err := opcode.FromString(s.Instruction) require.NoError(t, err) - require.Equal(t, op, opcode.Opcode(ctx.prog[ctx.nextip])) + require.Equal(t, op, opcode.Opcode(ctx.sc.prog[ctx.nextip])) } compareStacks(t, s.EStack, vm.estack) - compareSlots(t, s.StaticFields, ctx.static) + compareSlots(t, s.StaticFields, ctx.sc.static) } } @@ -240,8 +240,8 @@ func compareStacks(t *testing.T, expected []vmUTStackItem, actual *Stack) { compareItemArrays(t, expected, actual.Len(), func(i int) stackitem.Item { return actual.Peek(i).Item() }) } -func compareSlots(t *testing.T, expected []vmUTStackItem, actual *slot) { - if (actual == nil || *actual == nil) && len(expected) == 0 { +func compareSlots(t *testing.T, expected []vmUTStackItem, actual slot) { + if actual == nil && len(expected) == 0 { return } require.NotNil(t, actual) diff --git a/pkg/vm/opcodebench_test.go b/pkg/vm/opcodebench_test.go index 3ff86c09a..1494dbd9a 100644 --- a/pkg/vm/opcodebench_test.go +++ b/pkg/vm/opcodebench_test.go @@ -56,7 +56,7 @@ func opParamSlotsPushVM(op opcode.Opcode, param []byte, sslot int, slotloc int, return nil } if sslot != 0 { - v.Context().static.init(sslot, &v.refs) + v.Context().sc.static.init(sslot, &v.refs) } if slotloc != 0 && slotarg != 0 { v.Context().local.init(slotloc, &v.refs) diff --git a/pkg/vm/vm.go b/pkg/vm/vm.go index aea077e74..44e3f50d5 100644 --- a/pkg/vm/vm.go +++ b/pkg/vm/vm.go @@ -171,9 +171,7 @@ func (v *VM) PrintOps(out io.Writer) { w := tabwriter.NewWriter(out, 0, 0, 4, ' ', 0) fmt.Fprintln(w, "INDEX\tOPCODE\tPARAMETER") realctx := v.Context() - ctx := realctx.Copy() - ctx.ip = 0 - ctx.nextip = 0 + ctx := &Context{sc: realctx.sc} for { cursor := "" instr, parameter, err := ctx.Next() @@ -228,7 +226,7 @@ func (v *VM) PrintOps(out io.Writer) { } fmt.Fprintf(w, "%d\t%s\t%s%s\n", ctx.ip, instr, desc, cursor) - if ctx.nextip >= len(ctx.prog) { + if ctx.nextip >= len(ctx.sc.prog) { break } } @@ -246,7 +244,7 @@ func getOffsetDesc(ctx *Context, parameter []byte) string { // AddBreakPoint adds a breakpoint to the current context. func (v *VM) AddBreakPoint(n int) { ctx := v.Context() - ctx.breakPoints = append(ctx.breakPoints, n) + ctx.sc.breakPoints = append(ctx.sc.breakPoints, n) } // AddBreakPointRel adds a breakpoint relative to the current @@ -337,31 +335,31 @@ func (v *VM) LoadNEFMethod(exe *nef.File, caller util.Uint160, hash util.Uint160 // It should be used for calling from native contracts. func (v *VM) loadScriptWithCallingHash(b []byte, exe *nef.File, caller util.Uint160, hash util.Uint160, f callflag.CallFlag, rvcount int, offset int, onContextUnload ContextUnloadCallback) { - var sl slot - v.checkInvocationStackSize() ctx := NewContextWithParams(b, rvcount, offset) if rvcount != -1 || v.estack.Len() != 0 { v.estack = subStack(v.estack) } - ctx.estack = v.estack + parent := v.Context() + ctx.sc.estack = v.estack initStack(&ctx.tryStack, "exception", nil) - ctx.callFlag = f - ctx.static = &sl - ctx.scriptHash = hash - ctx.callingScriptHash = caller - ctx.NEF = exe + ctx.sc.callFlag = f + ctx.sc.scriptHash = hash + ctx.sc.callingScriptHash = caller + if parent != nil { + ctx.sc.callingContext = parent.sc + } + ctx.sc.NEF = exe if v.invTree != nil { curTree := v.invTree - parent := v.Context() if parent != nil { - curTree = parent.invTree + curTree = parent.sc.invTree } newTree := &invocations.Tree{Current: ctx.ScriptHash()} curTree.Calls = append(curTree.Calls, newTree) - ctx.invTree = newTree + ctx.sc.invTree = newTree } - ctx.onUnload = onContextUnload + ctx.sc.onUnload = onContextUnload v.istack.PushItem(ctx) } @@ -481,7 +479,7 @@ func (v *VM) StepInto() error { return nil } - if ctx != nil && ctx.prog != nil { + if ctx != nil && ctx.sc.prog != nil { op, param, err := ctx.Next() if err != nil { v.state = vmstate.Fault @@ -584,7 +582,7 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro } }() - if v.getPrice != nil && ctx.ip < len(ctx.prog) { + if v.getPrice != nil && ctx.ip < len(ctx.sc.prog) { v.gasConsumed += v.getPrice(op, parameter) if v.GasLimit >= 0 && v.gasConsumed > v.GasLimit { panic("gas limit is exceeded") @@ -610,7 +608,7 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro case opcode.PUSHA: n := getJumpOffset(ctx, parameter) - ptr := stackitem.NewPointerWithHash(n, ctx.prog, ctx.ScriptHash()) + ptr := stackitem.NewPointerWithHash(n, ctx.sc.prog, ctx.ScriptHash()) v.estack.PushItem(ptr) case opcode.PUSHNULL: @@ -637,7 +635,7 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro if parameter[0] == 0 { panic("zero argument") } - ctx.static.init(int(parameter[0]), &v.refs) + ctx.sc.static.init(int(parameter[0]), &v.refs) case opcode.INITSLOT: if ctx.local != nil || ctx.arguments != nil { @@ -658,20 +656,20 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro } case opcode.LDSFLD0, opcode.LDSFLD1, opcode.LDSFLD2, opcode.LDSFLD3, opcode.LDSFLD4, opcode.LDSFLD5, opcode.LDSFLD6: - item := ctx.static.Get(int(op - opcode.LDSFLD0)) + item := ctx.sc.static.Get(int(op - opcode.LDSFLD0)) v.estack.PushItem(item) case opcode.LDSFLD: - item := ctx.static.Get(int(parameter[0])) + item := ctx.sc.static.Get(int(parameter[0])) v.estack.PushItem(item) case opcode.STSFLD0, opcode.STSFLD1, opcode.STSFLD2, opcode.STSFLD3, opcode.STSFLD4, opcode.STSFLD5, opcode.STSFLD6: item := v.estack.Pop().Item() - ctx.static.Set(int(op-opcode.STSFLD0), item, &v.refs) + ctx.sc.static.Set(int(op-opcode.STSFLD0), item, &v.refs) case opcode.STSFLD: item := v.estack.Pop().Item() - ctx.static.Set(int(parameter[0]), item, &v.refs) + ctx.sc.static.Set(int(parameter[0]), item, &v.refs) case opcode.LDLOC0, opcode.LDLOC1, opcode.LDLOC2, opcode.LDLOC3, opcode.LDLOC4, opcode.LDLOC5, opcode.LDLOC6: item := ctx.local.Get(int(op - opcode.LDLOC0)) @@ -1475,7 +1473,7 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro break } - newEstack := v.Context().estack + newEstack := v.Context().sc.estack if oldEstack != newEstack { if oldCtx.retCount >= 0 && oldEstack.Len() != oldCtx.retCount { panic(fmt.Errorf("invalid return values count: expected %d, got %d", @@ -1631,17 +1629,19 @@ func (v *VM) unloadContext(ctx *Context) { ctx.arguments.ClearRefs(&v.refs) } currCtx := v.Context() - if ctx.static != nil && (currCtx == nil || ctx.static != currCtx.static) { - ctx.static.ClearRefs(&v.refs) - } - if ctx.onUnload != nil { - err := ctx.onUnload(v.uncaughtException == nil) - if err != nil { - errMessage := fmt.Sprintf("context unload callback failed: %s", err) - if v.uncaughtException != nil { - errMessage = fmt.Sprintf("%s, uncaught exception: %s", errMessage, v.uncaughtException) + if currCtx == nil || ctx.sc != currCtx.sc { + if ctx.sc.static != nil { + ctx.sc.static.ClearRefs(&v.refs) + } + if ctx.sc.onUnload != nil { + err := ctx.sc.onUnload(ctx, v.uncaughtException == nil) + if err != nil { + errMessage := fmt.Sprintf("context unload callback failed: %s", err) + if v.uncaughtException != nil { + errMessage = fmt.Sprintf("%s, uncaught exception: %s", errMessage, v.uncaughtException) + } + panic(errors.New(errMessage)) } - panic(errors.New(errMessage)) } } } @@ -1691,17 +1691,13 @@ func (v *VM) Call(offset int) { // package. func (v *VM) call(ctx *Context, offset int) { v.checkInvocationStackSize() - newCtx := ctx.Copy() - newCtx.retCount = -1 - newCtx.local = nil - newCtx.arguments = nil - // If memory for `elems` is reused, we can end up - // with an incorrect exception context state in the caller. - newCtx.tryStack.elems = nil - initStack(&newCtx.tryStack, "exception", nil) - newCtx.NEF = ctx.NEF - // Do not clone unloading callback, new context does not require any actions to perform on unloading. - newCtx.onUnload = nil + newCtx := &Context{ + sc: ctx.sc, + retCount: -1, + tryStack: ctx.tryStack, + } + // New context -> new exception handlers. + newCtx.tryStack.elems = ctx.tryStack.elems[len(ctx.tryStack.elems):] v.istack.PushItem(newCtx) newCtx.Jump(offset) } @@ -1732,7 +1728,7 @@ func calcJumpOffset(ctx *Context, parameter []byte) (int, int, error) { return 0, 0, fmt.Errorf("invalid %s parameter length: %d", curr, l) } offset := ctx.ip + int(rOffset) - if offset < 0 || offset > len(ctx.prog) { + if offset < 0 || offset > len(ctx.sc.prog) { return 0, 0, fmt.Errorf("invalid offset %d ip at %d", offset, ctx.ip) } @@ -1955,7 +1951,7 @@ func bytesToPublicKey(b []byte, curve elliptic.Curve) *keys.PublicKey { // GetCallingScriptHash implements the ScriptHashGetter interface. func (v *VM) GetCallingScriptHash() util.Uint160 { - return v.Context().callingScriptHash + return v.Context().sc.callingScriptHash } // GetEntryScriptHash implements the ScriptHashGetter interface.