From e5c59f8dddc1b0798d9eaf2ea7b4089170d92a40 Mon Sep 17 00:00:00 2001 From: Roman Khimov Date: Wed, 27 Jul 2022 14:49:53 +0300 Subject: [PATCH 1/4] interop/runtime: disable notifications in dynamic scripts That are only entry scripts today. See neo-project/neo#2796. --- pkg/core/interop/runtime/engine.go | 4 ++++ pkg/core/interop/runtime/engine_test.go | 13 ++++++++++++- 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/pkg/core/interop/runtime/engine.go b/pkg/core/interop/runtime/engine.go index 1223d878a..81dae7a2c 100644 --- a/pkg/core/interop/runtime/engine.go +++ b/pkg/core/interop/runtime/engine.go @@ -73,6 +73,10 @@ func Notify(ic *interop.Context) error { if len(name) > MaxEventNameLen { return fmt.Errorf("event name must be less than %d", MaxEventNameLen) } + if ic.VM.Context().NEF == nil { + return errors.New("notifications are not allowed in dynamic scripts") + } + // But it has to be serializable, otherwise we either have some broken // (recursive) structure inside or an interop item that can't be used // outside of the interop subsystem anyway. diff --git a/pkg/core/interop/runtime/engine_test.go b/pkg/core/interop/runtime/engine_test.go index fc3b99b85..4ddf628ae 100644 --- a/pkg/core/interop/runtime/engine_test.go +++ b/pkg/core/interop/runtime/engine_test.go @@ -12,6 +12,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/core/interop" "github.com/nspcc-dev/neo-go/pkg/crypto/hash" "github.com/nspcc-dev/neo-go/pkg/smartcontract/callflag" + "github.com/nspcc-dev/neo-go/pkg/smartcontract/nef" "github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger" "github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/vm" @@ -133,9 +134,12 @@ func TestLog(t *testing.T) { func TestNotify(t *testing.T) { h := random.Uint160() + caller := random.Uint160() + exe, err := nef.NewFile([]byte{1}) + require.NoError(t, err) newIC := func(name string, args interface{}) *interop.Context { ic := &interop.Context{VM: vm.New(), DAO: &dao.Simple{}} - ic.VM.LoadScriptWithHash([]byte{1}, h, callflag.NoneFlag) + ic.VM.LoadNEFMethod(exe, caller, h, callflag.NoneFlag, true, 0, -1, nil) ic.VM.Estack().PushVal(args) ic.VM.Estack().PushVal(name) return ic @@ -144,6 +148,13 @@ func TestNotify(t *testing.T) { ic := newIC(string(make([]byte, MaxEventNameLen+1)), stackitem.NewArray([]stackitem.Item{stackitem.Null{}})) require.Error(t, Notify(ic)) }) + t.Run("dynamic script", func(t *testing.T) { + ic := &interop.Context{VM: vm.New(), DAO: &dao.Simple{}} + ic.VM.LoadScriptWithHash([]byte{1}, h, callflag.NoneFlag) + ic.VM.Estack().PushVal(stackitem.NewArray([]stackitem.Item{stackitem.Make(42)})) + ic.VM.Estack().PushVal("event") + require.Error(t, Notify(ic)) + }) t.Run("recursive struct", func(t *testing.T) { arr := stackitem.NewArray([]stackitem.Item{stackitem.Null{}}) arr.Append(arr) From 13f5fdbe8a2046be9eb2a0ded3f0b739b74ca304 Mon Sep 17 00:00:00 2001 From: Roman Khimov Date: Thu, 4 Aug 2022 16:15:51 +0300 Subject: [PATCH 2/4] vm: extract shared parts of the Context Local calls reuse them, cross-contract calls create new ones. This allows to avoid some allocations and use a little less memory. --- pkg/core/interop/contract/call.go | 2 +- pkg/core/interop/runtime/engine.go | 2 +- pkg/vm/context.go | 108 ++++++++++++++++------------- pkg/vm/interop.go | 5 +- pkg/vm/json_test.go | 12 ++-- pkg/vm/opcodebench_test.go | 2 +- pkg/vm/vm.go | 89 +++++++++++------------- 7 files changed, 112 insertions(+), 108 deletions(-) diff --git a/pkg/core/interop/contract/call.go b/pkg/core/interop/contract/call.go index 3f47fd5ca..279bbb5e3 100644 --- a/pkg/core/interop/contract/call.go +++ b/pkg/core/interop/contract/call.go @@ -27,7 +27,7 @@ func LoadToken(ic *interop.Context, id int32) error { if !ctx.GetCallFlags().Has(callflag.ReadStates | callflag.AllowCall) { return errors.New("invalid call flags") } - tok := ctx.NEF.Tokens[id] + tok := ctx.GetNEF().Tokens[id] if int(tok.ParamCount) > ctx.Estack().Len() { return errors.New("stack is too small") } diff --git a/pkg/core/interop/runtime/engine.go b/pkg/core/interop/runtime/engine.go index 81dae7a2c..73654db19 100644 --- a/pkg/core/interop/runtime/engine.go +++ b/pkg/core/interop/runtime/engine.go @@ -73,7 +73,7 @@ func Notify(ic *interop.Context) error { if len(name) > MaxEventNameLen { return fmt.Errorf("event name must be less than %d", MaxEventNameLen) } - if ic.VM.Context().NEF == nil { + if !ic.VM.Context().IsDeployed() { return errors.New("notifications are not allowed in dynamic scripts") } diff --git a/pkg/vm/context.go b/pkg/vm/context.go index 955554482..aaf74b01e 100644 --- a/pkg/vm/context.go +++ b/pkg/vm/context.go @@ -16,14 +16,10 @@ import ( "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" ) -// Context represents the current execution context of the VM. -type Context struct { - // Instruction pointer. - ip int - - // The next instruction pointer. - nextip int - +// scriptContext is a part of the Context that is shared between multiple Contexts, +// it's created when a new script is loaded into the VM while regular +// CALL/CALLL/CALLA internal invocations reuse it. +type scriptContext struct { // The raw program script. prog []byte @@ -33,12 +29,7 @@ type Context struct { // Evaluation stack pointer. estack *Stack - static *slot - local slot - arguments slot - - // Exception context stack. - tryStack Stack + static slot // Script hash of the prog. scriptHash util.Uint160 @@ -49,17 +40,35 @@ type Context struct { // Call flags this context was created with. callFlag callflag.CallFlag - // retCount specifies the number of return values. - retCount int // NEF represents a NEF file for the current contract. NEF *nef.File - // invTree is an invocation tree (or branch of it) for this context. + // invTree is an invocation tree (or a branch of it) for this context. invTree *invocations.Tree // onUnload is a callback that should be called after current context unloading // if no exception occurs. onUnload ContextUnloadCallback } +// Context represents the current execution context of the VM. +type Context struct { + // Instruction pointer. + ip int + + // The next instruction pointer. + nextip int + + sc *scriptContext + + local slot + arguments slot + + // Exception context stack. + tryStack Stack + + // retCount specifies the number of return values. + retCount int +} + // ContextUnloadCallback is a callback method used on context unloading from istack. type ContextUnloadCallback func(commit bool) error @@ -74,7 +83,9 @@ func NewContext(b []byte) *Context { // return value count and initial position in script. func NewContextWithParams(b []byte, rvcount int, pos int) *Context { return &Context{ - prog: b, + sc: &scriptContext{ + prog: b, + }, retCount: rvcount, nextip: pos, } @@ -82,7 +93,7 @@ func NewContextWithParams(b []byte, rvcount int, pos int) *Context { // Estack returns the evaluation stack of c. func (c *Context) Estack() *Stack { - return c.estack + return c.sc.estack } // NextIP returns the next instruction pointer. @@ -92,7 +103,7 @@ func (c *Context) NextIP() int { // Jump unconditionally moves the next instruction pointer to the specified location. func (c *Context) Jump(pos int) { - if pos < 0 || pos >= len(c.prog) { + if pos < 0 || pos >= len(c.sc.prog) { panic("instruction offset is out of range") } c.nextip = pos @@ -105,11 +116,12 @@ func (c *Context) Next() (opcode.Opcode, []byte, error) { var err error c.ip = c.nextip - if c.ip >= len(c.prog) { + prog := c.sc.prog + if c.ip >= len(prog) { return opcode.RET, nil, nil } - var instrbyte = c.prog[c.ip] + var instrbyte = prog[c.ip] instr := opcode.Opcode(instrbyte) if !opcode.IsValid(instr) { return instr, nil, fmt.Errorf("incorrect opcode %s", instr.String()) @@ -119,24 +131,24 @@ func (c *Context) Next() (opcode.Opcode, []byte, error) { var numtoread int switch instr { case opcode.PUSHDATA1: - if c.nextip >= len(c.prog) { + if c.nextip >= len(prog) { err = errNoInstParam } else { - numtoread = int(c.prog[c.nextip]) + numtoread = int(prog[c.nextip]) c.nextip++ } case opcode.PUSHDATA2: - if c.nextip+1 >= len(c.prog) { + if c.nextip+1 >= len(prog) { err = errNoInstParam } else { - numtoread = int(binary.LittleEndian.Uint16(c.prog[c.nextip : c.nextip+2])) + numtoread = int(binary.LittleEndian.Uint16(prog[c.nextip : c.nextip+2])) c.nextip += 2 } case opcode.PUSHDATA4: - if c.nextip+3 >= len(c.prog) { + if c.nextip+3 >= len(prog) { err = errNoInstParam } else { - var n = binary.LittleEndian.Uint32(c.prog[c.nextip : c.nextip+4]) + var n = binary.LittleEndian.Uint32(prog[c.nextip : c.nextip+4]) if n > stackitem.MaxSize { return instr, nil, errors.New("parameter is too big") } @@ -166,13 +178,13 @@ func (c *Context) Next() (opcode.Opcode, []byte, error) { return instr, nil, nil } } - if c.nextip+numtoread-1 >= len(c.prog) { + if c.nextip+numtoread-1 >= len(prog) { err = errNoInstParam } if err != nil { return instr, nil, err } - parameter := c.prog[c.nextip : c.nextip+numtoread] + parameter := prog[c.nextip : c.nextip+numtoread] c.nextip += numtoread return instr, parameter, nil } @@ -184,46 +196,44 @@ func (c *Context) IP() int { // LenInstr returns the number of instructions loaded. func (c *Context) LenInstr() int { - return len(c.prog) + return len(c.sc.prog) } // CurrInstr returns the current instruction and opcode. func (c *Context) CurrInstr() (int, opcode.Opcode) { - return c.ip, opcode.Opcode(c.prog[c.ip]) + return c.ip, opcode.Opcode(c.sc.prog[c.ip]) } // NextInstr returns the next instruction and opcode. func (c *Context) NextInstr() (int, opcode.Opcode) { op := opcode.RET - if c.nextip < len(c.prog) { - op = opcode.Opcode(c.prog[c.nextip]) + if c.nextip < len(c.sc.prog) { + op = opcode.Opcode(c.sc.prog[c.nextip]) } return c.nextip, op } -// Copy returns an new exact copy of c. -func (c *Context) Copy() *Context { - ctx := new(Context) - *ctx = *c - return ctx -} - // GetCallFlags returns the calling flags which the context was created with. func (c *Context) GetCallFlags() callflag.CallFlag { - return c.callFlag + return c.sc.callFlag } // Program returns the loaded program. func (c *Context) Program() []byte { - return c.prog + return c.sc.prog } // ScriptHash returns a hash of the script in the current context. func (c *Context) ScriptHash() util.Uint160 { - if c.scriptHash.Equals(util.Uint160{}) { - c.scriptHash = hash.Hash160(c.prog) + if c.sc.scriptHash.Equals(util.Uint160{}) { + c.sc.scriptHash = hash.Hash160(c.sc.prog) } - return c.scriptHash + return c.sc.scriptHash +} + +// GetNEF returns NEF structure used by this context if it's present. +func (c *Context) GetNEF() *nef.File { + return c.sc.NEF } // Value implements the stackitem.Item interface. @@ -263,7 +273,7 @@ func (c *Context) Equals(s stackitem.Item) bool { } func (c *Context) atBreakPoint() bool { - for _, n := range c.breakPoints { + for _, n := range c.sc.breakPoints { if n == c.nextip { return true } @@ -277,12 +287,12 @@ func (c *Context) String() string { // IsDeployed returns whether this context contains a deployed contract. func (c *Context) IsDeployed() bool { - return c.NEF != nil + return c.sc.NEF != nil } // DumpStaticSlot returns json formatted representation of the given slot. func (c *Context) DumpStaticSlot() string { - return dumpSlot(c.static) + return dumpSlot(&c.sc.static) } // DumpLocalSlot returns json formatted representation of the given slot. diff --git a/pkg/vm/interop.go b/pkg/vm/interop.go index 4885f255e..27156b31f 100644 --- a/pkg/vm/interop.go +++ b/pkg/vm/interop.go @@ -36,8 +36,9 @@ func defaultSyscallHandler(v *VM, id uint32) error { return errors.New("syscall not found") } d := defaultVMInterops[n] - if !v.Context().callFlag.Has(d.RequiredFlags) { - return fmt.Errorf("missing call flags: %05b vs %05b", v.Context().callFlag, d.RequiredFlags) + ctxFlag := v.Context().sc.callFlag + if !ctxFlag.Has(d.RequiredFlags) { + return fmt.Errorf("missing call flags: %05b vs %05b", ctxFlag, d.RequiredFlags) } return d.Func(v) } diff --git a/pkg/vm/json_test.go b/pkg/vm/json_test.go index 86a68540d..8256a6840 100644 --- a/pkg/vm/json_test.go +++ b/pkg/vm/json_test.go @@ -115,7 +115,7 @@ func testSyscallHandler(v *VM, id uint32) error { case 0x77777777: v.Estack().PushVal(stackitem.NewInterop(new(int))) case 0x66666666: - if !v.Context().callFlag.Has(callflag.ReadOnly) { + if !v.Context().sc.callFlag.Has(callflag.ReadOnly) { return errors.New("invalid call flags") } v.Estack().PushVal(stackitem.NewInterop(new(int))) @@ -167,14 +167,14 @@ func testFile(t *testing.T, filename string) { if len(result.InvocationStack) > 0 { for i, s := range result.InvocationStack { ctx := vm.istack.Peek(i).Value().(*Context) - if ctx.nextip < len(ctx.prog) { + if ctx.nextip < len(ctx.sc.prog) { require.Equal(t, s.InstructionPointer, ctx.nextip) op, err := opcode.FromString(s.Instruction) require.NoError(t, err) - require.Equal(t, op, opcode.Opcode(ctx.prog[ctx.nextip])) + require.Equal(t, op, opcode.Opcode(ctx.sc.prog[ctx.nextip])) } compareStacks(t, s.EStack, vm.estack) - compareSlots(t, s.StaticFields, ctx.static) + compareSlots(t, s.StaticFields, ctx.sc.static) } } @@ -240,8 +240,8 @@ func compareStacks(t *testing.T, expected []vmUTStackItem, actual *Stack) { compareItemArrays(t, expected, actual.Len(), func(i int) stackitem.Item { return actual.Peek(i).Item() }) } -func compareSlots(t *testing.T, expected []vmUTStackItem, actual *slot) { - if (actual == nil || *actual == nil) && len(expected) == 0 { +func compareSlots(t *testing.T, expected []vmUTStackItem, actual slot) { + if actual == nil && len(expected) == 0 { return } require.NotNil(t, actual) diff --git a/pkg/vm/opcodebench_test.go b/pkg/vm/opcodebench_test.go index 3ff86c09a..1494dbd9a 100644 --- a/pkg/vm/opcodebench_test.go +++ b/pkg/vm/opcodebench_test.go @@ -56,7 +56,7 @@ func opParamSlotsPushVM(op opcode.Opcode, param []byte, sslot int, slotloc int, return nil } if sslot != 0 { - v.Context().static.init(sslot, &v.refs) + v.Context().sc.static.init(sslot, &v.refs) } if slotloc != 0 && slotarg != 0 { v.Context().local.init(slotloc, &v.refs) diff --git a/pkg/vm/vm.go b/pkg/vm/vm.go index aea077e74..47c68a772 100644 --- a/pkg/vm/vm.go +++ b/pkg/vm/vm.go @@ -171,9 +171,7 @@ func (v *VM) PrintOps(out io.Writer) { w := tabwriter.NewWriter(out, 0, 0, 4, ' ', 0) fmt.Fprintln(w, "INDEX\tOPCODE\tPARAMETER") realctx := v.Context() - ctx := realctx.Copy() - ctx.ip = 0 - ctx.nextip = 0 + ctx := &Context{sc: realctx.sc} for { cursor := "" instr, parameter, err := ctx.Next() @@ -228,7 +226,7 @@ func (v *VM) PrintOps(out io.Writer) { } fmt.Fprintf(w, "%d\t%s\t%s%s\n", ctx.ip, instr, desc, cursor) - if ctx.nextip >= len(ctx.prog) { + if ctx.nextip >= len(ctx.sc.prog) { break } } @@ -246,7 +244,7 @@ func getOffsetDesc(ctx *Context, parameter []byte) string { // AddBreakPoint adds a breakpoint to the current context. func (v *VM) AddBreakPoint(n int) { ctx := v.Context() - ctx.breakPoints = append(ctx.breakPoints, n) + ctx.sc.breakPoints = append(ctx.sc.breakPoints, n) } // AddBreakPointRel adds a breakpoint relative to the current @@ -337,31 +335,28 @@ func (v *VM) LoadNEFMethod(exe *nef.File, caller util.Uint160, hash util.Uint160 // It should be used for calling from native contracts. func (v *VM) loadScriptWithCallingHash(b []byte, exe *nef.File, caller util.Uint160, hash util.Uint160, f callflag.CallFlag, rvcount int, offset int, onContextUnload ContextUnloadCallback) { - var sl slot - v.checkInvocationStackSize() ctx := NewContextWithParams(b, rvcount, offset) if rvcount != -1 || v.estack.Len() != 0 { v.estack = subStack(v.estack) } - ctx.estack = v.estack + ctx.sc.estack = v.estack initStack(&ctx.tryStack, "exception", nil) - ctx.callFlag = f - ctx.static = &sl - ctx.scriptHash = hash - ctx.callingScriptHash = caller - ctx.NEF = exe + ctx.sc.callFlag = f + ctx.sc.scriptHash = hash + ctx.sc.callingScriptHash = caller + ctx.sc.NEF = exe if v.invTree != nil { curTree := v.invTree parent := v.Context() if parent != nil { - curTree = parent.invTree + curTree = parent.sc.invTree } newTree := &invocations.Tree{Current: ctx.ScriptHash()} curTree.Calls = append(curTree.Calls, newTree) - ctx.invTree = newTree + ctx.sc.invTree = newTree } - ctx.onUnload = onContextUnload + ctx.sc.onUnload = onContextUnload v.istack.PushItem(ctx) } @@ -481,7 +476,7 @@ func (v *VM) StepInto() error { return nil } - if ctx != nil && ctx.prog != nil { + if ctx != nil && ctx.sc.prog != nil { op, param, err := ctx.Next() if err != nil { v.state = vmstate.Fault @@ -584,7 +579,7 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro } }() - if v.getPrice != nil && ctx.ip < len(ctx.prog) { + if v.getPrice != nil && ctx.ip < len(ctx.sc.prog) { v.gasConsumed += v.getPrice(op, parameter) if v.GasLimit >= 0 && v.gasConsumed > v.GasLimit { panic("gas limit is exceeded") @@ -610,7 +605,7 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro case opcode.PUSHA: n := getJumpOffset(ctx, parameter) - ptr := stackitem.NewPointerWithHash(n, ctx.prog, ctx.ScriptHash()) + ptr := stackitem.NewPointerWithHash(n, ctx.sc.prog, ctx.ScriptHash()) v.estack.PushItem(ptr) case opcode.PUSHNULL: @@ -637,7 +632,7 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro if parameter[0] == 0 { panic("zero argument") } - ctx.static.init(int(parameter[0]), &v.refs) + ctx.sc.static.init(int(parameter[0]), &v.refs) case opcode.INITSLOT: if ctx.local != nil || ctx.arguments != nil { @@ -658,20 +653,20 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro } case opcode.LDSFLD0, opcode.LDSFLD1, opcode.LDSFLD2, opcode.LDSFLD3, opcode.LDSFLD4, opcode.LDSFLD5, opcode.LDSFLD6: - item := ctx.static.Get(int(op - opcode.LDSFLD0)) + item := ctx.sc.static.Get(int(op - opcode.LDSFLD0)) v.estack.PushItem(item) case opcode.LDSFLD: - item := ctx.static.Get(int(parameter[0])) + item := ctx.sc.static.Get(int(parameter[0])) v.estack.PushItem(item) case opcode.STSFLD0, opcode.STSFLD1, opcode.STSFLD2, opcode.STSFLD3, opcode.STSFLD4, opcode.STSFLD5, opcode.STSFLD6: item := v.estack.Pop().Item() - ctx.static.Set(int(op-opcode.STSFLD0), item, &v.refs) + ctx.sc.static.Set(int(op-opcode.STSFLD0), item, &v.refs) case opcode.STSFLD: item := v.estack.Pop().Item() - ctx.static.Set(int(parameter[0]), item, &v.refs) + ctx.sc.static.Set(int(parameter[0]), item, &v.refs) case opcode.LDLOC0, opcode.LDLOC1, opcode.LDLOC2, opcode.LDLOC3, opcode.LDLOC4, opcode.LDLOC5, opcode.LDLOC6: item := ctx.local.Get(int(op - opcode.LDLOC0)) @@ -1475,7 +1470,7 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro break } - newEstack := v.Context().estack + newEstack := v.Context().sc.estack if oldEstack != newEstack { if oldCtx.retCount >= 0 && oldEstack.Len() != oldCtx.retCount { panic(fmt.Errorf("invalid return values count: expected %d, got %d", @@ -1631,17 +1626,19 @@ func (v *VM) unloadContext(ctx *Context) { ctx.arguments.ClearRefs(&v.refs) } currCtx := v.Context() - if ctx.static != nil && (currCtx == nil || ctx.static != currCtx.static) { - ctx.static.ClearRefs(&v.refs) - } - if ctx.onUnload != nil { - err := ctx.onUnload(v.uncaughtException == nil) - if err != nil { - errMessage := fmt.Sprintf("context unload callback failed: %s", err) - if v.uncaughtException != nil { - errMessage = fmt.Sprintf("%s, uncaught exception: %s", errMessage, v.uncaughtException) + if currCtx == nil || ctx.sc != currCtx.sc { + if ctx.sc.static != nil { + ctx.sc.static.ClearRefs(&v.refs) + } + if ctx.sc.onUnload != nil { + err := ctx.sc.onUnload(v.uncaughtException == nil) + if err != nil { + errMessage := fmt.Sprintf("context unload callback failed: %s", err) + if v.uncaughtException != nil { + errMessage = fmt.Sprintf("%s, uncaught exception: %s", errMessage, v.uncaughtException) + } + panic(errors.New(errMessage)) } - panic(errors.New(errMessage)) } } } @@ -1691,17 +1688,13 @@ func (v *VM) Call(offset int) { // package. func (v *VM) call(ctx *Context, offset int) { v.checkInvocationStackSize() - newCtx := ctx.Copy() - newCtx.retCount = -1 - newCtx.local = nil - newCtx.arguments = nil - // If memory for `elems` is reused, we can end up - // with an incorrect exception context state in the caller. - newCtx.tryStack.elems = nil - initStack(&newCtx.tryStack, "exception", nil) - newCtx.NEF = ctx.NEF - // Do not clone unloading callback, new context does not require any actions to perform on unloading. - newCtx.onUnload = nil + newCtx := &Context{ + sc: ctx.sc, + retCount: -1, + tryStack: ctx.tryStack, + } + // New context -> new exception handlers. + newCtx.tryStack.elems = ctx.tryStack.elems[len(ctx.tryStack.elems):] v.istack.PushItem(newCtx) newCtx.Jump(offset) } @@ -1732,7 +1725,7 @@ func calcJumpOffset(ctx *Context, parameter []byte) (int, int, error) { return 0, 0, fmt.Errorf("invalid %s parameter length: %d", curr, l) } offset := ctx.ip + int(rOffset) - if offset < 0 || offset > len(ctx.prog) { + if offset < 0 || offset > len(ctx.sc.prog) { return 0, 0, fmt.Errorf("invalid offset %d ip at %d", offset, ctx.ip) } @@ -1955,7 +1948,7 @@ func bytesToPublicKey(b []byte, curve elliptic.Curve) *keys.PublicKey { // GetCallingScriptHash implements the ScriptHashGetter interface. func (v *VM) GetCallingScriptHash() util.Uint160 { - return v.Context().callingScriptHash + return v.Context().sc.callingScriptHash } // GetEntryScriptHash implements the ScriptHashGetter interface. From 99e2681d3afa037784d7702212a01e8d8c689442 Mon Sep 17 00:00:00 2001 From: Roman Khimov Date: Thu, 4 Aug 2022 16:35:02 +0300 Subject: [PATCH 3/4] interop/vm: use more robust CalledByEntry check Directly check contexts. --- pkg/core/interop/runtime/witness.go | 8 +++++--- pkg/core/transaction/witness_condition.go | 5 ++--- pkg/core/transaction/witness_condition_test.go | 3 +++ pkg/core/transaction/witness_scope.go | 2 +- pkg/vm/context.go | 9 +++++++++ pkg/vm/vm.go | 5 ++++- 6 files changed, 24 insertions(+), 8 deletions(-) diff --git a/pkg/core/interop/runtime/witness.go b/pkg/core/interop/runtime/witness.go index d6a01e652..09c4335d4 100644 --- a/pkg/core/interop/runtime/witness.go +++ b/pkg/core/interop/runtime/witness.go @@ -41,6 +41,10 @@ func getContractGroups(v *vm.VM, ic *interop.Context, h util.Uint160) (manifest. return manifest.Groups(cs.Manifest.Groups), nil } +func (sc scopeContext) IsCalledByEntry() bool { + return sc.VM.Context().IsCalledByEntry() +} + func (sc scopeContext) checkScriptGroups(h util.Uint160, k *keys.PublicKey) (bool, error) { groups, err := getContractGroups(sc.VM, sc.ic, h) if err != nil { @@ -69,9 +73,7 @@ func checkScope(ic *interop.Context, hash util.Uint160) (bool, error) { return true, nil } if c.Scopes&transaction.CalledByEntry != 0 { - callingScriptHash := ic.VM.GetCallingScriptHash() - entryScriptHash := ic.VM.GetEntryScriptHash() - if callingScriptHash.Equals(util.Uint160{}) || callingScriptHash == entryScriptHash { + if ic.VM.Context().IsCalledByEntry() { return true, nil } } diff --git a/pkg/core/transaction/witness_condition.go b/pkg/core/transaction/witness_condition.go index 71fcc1d33..65a73b09e 100644 --- a/pkg/core/transaction/witness_condition.go +++ b/pkg/core/transaction/witness_condition.go @@ -63,9 +63,9 @@ type WitnessCondition interface { type MatchContext interface { GetCallingScriptHash() util.Uint160 GetCurrentScriptHash() util.Uint160 - GetEntryScriptHash() util.Uint160 CallingScriptHasGroup(*keys.PublicKey) (bool, error) CurrentScriptHasGroup(*keys.PublicKey) (bool, error) + IsCalledByEntry() bool } type ( @@ -394,8 +394,7 @@ func (c ConditionCalledByEntry) Type() WitnessConditionType { // Match implements the WitnessCondition interface checking whether this condition // matches given context. func (c ConditionCalledByEntry) Match(ctx MatchContext) (bool, error) { - entry := ctx.GetEntryScriptHash() - return entry.Equals(ctx.GetCallingScriptHash()) || entry.Equals(ctx.GetCurrentScriptHash()), nil + return ctx.IsCalledByEntry(), nil } // EncodeBinary implements the WitnessCondition interface allowing to serialize condition. diff --git a/pkg/core/transaction/witness_condition_test.go b/pkg/core/transaction/witness_condition_test.go index e8ac6d4fc..2a4f56244 100644 --- a/pkg/core/transaction/witness_condition_test.go +++ b/pkg/core/transaction/witness_condition_test.go @@ -170,6 +170,9 @@ func (t *TestMC) GetCurrentScriptHash() util.Uint160 { func (t *TestMC) GetEntryScriptHash() util.Uint160 { return t.entry } +func (t *TestMC) IsCalledByEntry() bool { + return t.entry.Equals(t.calling) || t.calling.Equals(util.Uint160{}) +} func (t *TestMC) CallingScriptHasGroup(k *keys.PublicKey) (bool, error) { res, err := t.CurrentScriptHasGroup(k) return !res, err // To differentiate from current we invert the logic value. diff --git a/pkg/core/transaction/witness_scope.go b/pkg/core/transaction/witness_scope.go index 5daf15eaa..aa67eb33a 100644 --- a/pkg/core/transaction/witness_scope.go +++ b/pkg/core/transaction/witness_scope.go @@ -13,7 +13,7 @@ type WitnessScope byte const ( // None specifies that no contract was witnessed. Only sign the transaction. None WitnessScope = 0 - // CalledByEntry means that this condition must hold: EntryScriptHash == CallingScriptHash. + // CalledByEntry witness is only valid in entry script and ones directly called by it. // No params is needed, as the witness/permission/signature given on first invocation will // automatically expire if entering deeper internal invokes. This can be default safe // choice for native NEO/GAS (previously used on Neo 2 as "attach" mode). diff --git a/pkg/vm/context.go b/pkg/vm/context.go index aaf74b01e..31c4280f7 100644 --- a/pkg/vm/context.go +++ b/pkg/vm/context.go @@ -37,6 +37,9 @@ type scriptContext struct { // Caller's contract script hash. callingScriptHash util.Uint160 + // Caller's scriptContext, if not entry. + callingContext *scriptContext + // Call flags this context was created with. callFlag callflag.CallFlag @@ -326,6 +329,12 @@ func (v *VM) getContextScriptHash(n int) util.Uint160 { return ctx.ScriptHash() } +// IsCalledByEntry checks parent script contexts and return true if the current one +// is an entry script (the first loaded into the VM) or one called by it. +func (c *Context) IsCalledByEntry() bool { + return c.sc.callingContext == nil || c.sc.callingContext.callingContext == nil +} + // PushContextScriptHash pushes the script hash of the // invocation stack element number n to the evaluation stack. func (v *VM) PushContextScriptHash(n int) error { diff --git a/pkg/vm/vm.go b/pkg/vm/vm.go index 47c68a772..378a898b2 100644 --- a/pkg/vm/vm.go +++ b/pkg/vm/vm.go @@ -340,15 +340,18 @@ func (v *VM) loadScriptWithCallingHash(b []byte, exe *nef.File, caller util.Uint if rvcount != -1 || v.estack.Len() != 0 { v.estack = subStack(v.estack) } + parent := v.Context() ctx.sc.estack = v.estack initStack(&ctx.tryStack, "exception", nil) ctx.sc.callFlag = f ctx.sc.scriptHash = hash ctx.sc.callingScriptHash = caller + if parent != nil { + ctx.sc.callingContext = parent.sc + } ctx.sc.NEF = exe if v.invTree != nil { curTree := v.invTree - parent := v.Context() if parent != nil { curTree = parent.sc.invTree } From e8d2277fe56201a00d19e4cfae807f1d6032d0bb Mon Sep 17 00:00:00 2001 From: Roman Khimov Date: Thu, 4 Aug 2022 18:17:32 +0300 Subject: [PATCH 4/4] contract/vm: only push NULL after call in dynamic contexts And determine the need for Null dynamically. For some reason the only dynamic context is Contract.Call. CALLT is not dynamic and neither is a call from native contract, go figure... --- pkg/core/interop/contract/call.go | 20 +++++++++++++------- pkg/vm/context.go | 7 ++++++- pkg/vm/vm.go | 2 +- 3 files changed, 20 insertions(+), 9 deletions(-) diff --git a/pkg/core/interop/contract/call.go b/pkg/core/interop/contract/call.go index 279bbb5e3..aa5989c6d 100644 --- a/pkg/core/interop/contract/call.go +++ b/pkg/core/interop/contract/call.go @@ -14,6 +14,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/smartcontract/callflag" "github.com/nspcc-dev/neo-go/pkg/smartcontract/manifest" "github.com/nspcc-dev/neo-go/pkg/util" + "github.com/nspcc-dev/neo-go/pkg/vm" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" ) @@ -67,11 +68,11 @@ func Call(ic *interop.Context) error { return fmt.Errorf("method not found: %s/%d", method, len(args)) } hasReturn := md.ReturnType != smartcontract.VoidType - return callInternal(ic, cs, method, fs, hasReturn, args, !hasReturn) + return callInternal(ic, cs, method, fs, hasReturn, args, true) } func callInternal(ic *interop.Context, cs *state.Contract, name string, f callflag.CallFlag, - hasReturn bool, args []stackitem.Item, pushNullOnUnloading bool) error { + hasReturn bool, args []stackitem.Item, isDynamic bool) error { md := cs.Manifest.ABI.GetMethod(name, len(args)) if md.Safe { f &^= (callflag.WriteStates | callflag.AllowNotify) @@ -83,12 +84,12 @@ func callInternal(ic *interop.Context, cs *state.Contract, name string, f callfl } } } - return callExFromNative(ic, ic.VM.GetCurrentScriptHash(), cs, name, args, f, hasReturn, pushNullOnUnloading, false) + return callExFromNative(ic, ic.VM.GetCurrentScriptHash(), cs, name, args, f, hasReturn, isDynamic, false) } // callExFromNative calls a contract with flags using the provided calling hash. func callExFromNative(ic *interop.Context, caller util.Uint160, cs *state.Contract, - name string, args []stackitem.Item, f callflag.CallFlag, hasReturn bool, pushNullOnUnloading bool, callFromNative bool) error { + name string, args []stackitem.Item, f callflag.CallFlag, hasReturn bool, isDynamic bool, callFromNative bool) error { for _, nc := range ic.Natives { if nc.Metadata().Name == nativenames.Policy { var pch = nc.(policyChecker) @@ -123,7 +124,7 @@ func callExFromNative(ic *interop.Context, caller util.Uint160, cs *state.Contra if wrapped { ic.DAO = ic.DAO.GetPrivate() } - onUnload := func(commit bool) error { + onUnload := func(ctx *vm.Context, commit bool) error { if wrapped { if commit { _, err := ic.DAO.Persist() @@ -135,8 +136,13 @@ func callExFromNative(ic *interop.Context, caller util.Uint160, cs *state.Contra } ic.DAO = baseDAO } - if pushNullOnUnloading && commit { - ic.VM.Context().Estack().PushItem(stackitem.Null{}) // Must use current context stack. + if isDynamic && commit { + eLen := ctx.Estack().Len() + if eLen == 0 && ctx.NumOfReturnVals() == 0 { // No return value and none expected. + ic.VM.Context().Estack().PushItem(stackitem.Null{}) // Must use current context stack. + } else if eLen > 1 { // 1 or -1 (all) retrun values expected, but only one can be returned. + return errors.New("multiple return values in a cross-contract call") + } // All other rvcount/stack length mismatches are checked by the VM. } if callFromNative && !commit { return fmt.Errorf("unhandled exception") diff --git a/pkg/vm/context.go b/pkg/vm/context.go index 31c4280f7..6c086249f 100644 --- a/pkg/vm/context.go +++ b/pkg/vm/context.go @@ -73,7 +73,7 @@ type Context struct { } // ContextUnloadCallback is a callback method used on context unloading from istack. -type ContextUnloadCallback func(commit bool) error +type ContextUnloadCallback func(ctx *Context, commit bool) error var errNoInstParam = errors.New("failed to read instruction parameter") @@ -239,6 +239,11 @@ func (c *Context) GetNEF() *nef.File { return c.sc.NEF } +// NumOfReturnVals returns the number of return values expected from this context. +func (c *Context) NumOfReturnVals() int { + return c.retCount +} + // Value implements the stackitem.Item interface. func (c *Context) Value() interface{} { return c diff --git a/pkg/vm/vm.go b/pkg/vm/vm.go index 378a898b2..44e3f50d5 100644 --- a/pkg/vm/vm.go +++ b/pkg/vm/vm.go @@ -1634,7 +1634,7 @@ func (v *VM) unloadContext(ctx *Context) { ctx.sc.static.ClearRefs(&v.refs) } if ctx.sc.onUnload != nil { - err := ctx.sc.onUnload(v.uncaughtException == nil) + err := ctx.sc.onUnload(ctx, v.uncaughtException == nil) if err != nil { errMessage := fmt.Sprintf("context unload callback failed: %s", err) if v.uncaughtException != nil {