From 13f5fdbe8a2046be9eb2a0ded3f0b739b74ca304 Mon Sep 17 00:00:00 2001 From: Roman Khimov Date: Thu, 4 Aug 2022 16:15:51 +0300 Subject: [PATCH] vm: extract shared parts of the Context Local calls reuse them, cross-contract calls create new ones. This allows to avoid some allocations and use a little less memory. --- pkg/core/interop/contract/call.go | 2 +- pkg/core/interop/runtime/engine.go | 2 +- pkg/vm/context.go | 108 ++++++++++++++++------------- pkg/vm/interop.go | 5 +- pkg/vm/json_test.go | 12 ++-- pkg/vm/opcodebench_test.go | 2 +- pkg/vm/vm.go | 89 +++++++++++------------- 7 files changed, 112 insertions(+), 108 deletions(-) diff --git a/pkg/core/interop/contract/call.go b/pkg/core/interop/contract/call.go index 3f47fd5ca..279bbb5e3 100644 --- a/pkg/core/interop/contract/call.go +++ b/pkg/core/interop/contract/call.go @@ -27,7 +27,7 @@ func LoadToken(ic *interop.Context, id int32) error { if !ctx.GetCallFlags().Has(callflag.ReadStates | callflag.AllowCall) { return errors.New("invalid call flags") } - tok := ctx.NEF.Tokens[id] + tok := ctx.GetNEF().Tokens[id] if int(tok.ParamCount) > ctx.Estack().Len() { return errors.New("stack is too small") } diff --git a/pkg/core/interop/runtime/engine.go b/pkg/core/interop/runtime/engine.go index 81dae7a2c..73654db19 100644 --- a/pkg/core/interop/runtime/engine.go +++ b/pkg/core/interop/runtime/engine.go @@ -73,7 +73,7 @@ func Notify(ic *interop.Context) error { if len(name) > MaxEventNameLen { return fmt.Errorf("event name must be less than %d", MaxEventNameLen) } - if ic.VM.Context().NEF == nil { + if !ic.VM.Context().IsDeployed() { return errors.New("notifications are not allowed in dynamic scripts") } diff --git a/pkg/vm/context.go b/pkg/vm/context.go index 955554482..aaf74b01e 100644 --- a/pkg/vm/context.go +++ b/pkg/vm/context.go @@ -16,14 +16,10 @@ import ( "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" ) -// Context represents the current execution context of the VM. -type Context struct { - // Instruction pointer. - ip int - - // The next instruction pointer. - nextip int - +// scriptContext is a part of the Context that is shared between multiple Contexts, +// it's created when a new script is loaded into the VM while regular +// CALL/CALLL/CALLA internal invocations reuse it. +type scriptContext struct { // The raw program script. prog []byte @@ -33,12 +29,7 @@ type Context struct { // Evaluation stack pointer. estack *Stack - static *slot - local slot - arguments slot - - // Exception context stack. - tryStack Stack + static slot // Script hash of the prog. scriptHash util.Uint160 @@ -49,17 +40,35 @@ type Context struct { // Call flags this context was created with. callFlag callflag.CallFlag - // retCount specifies the number of return values. - retCount int // NEF represents a NEF file for the current contract. NEF *nef.File - // invTree is an invocation tree (or branch of it) for this context. + // invTree is an invocation tree (or a branch of it) for this context. invTree *invocations.Tree // onUnload is a callback that should be called after current context unloading // if no exception occurs. onUnload ContextUnloadCallback } +// Context represents the current execution context of the VM. +type Context struct { + // Instruction pointer. + ip int + + // The next instruction pointer. + nextip int + + sc *scriptContext + + local slot + arguments slot + + // Exception context stack. + tryStack Stack + + // retCount specifies the number of return values. + retCount int +} + // ContextUnloadCallback is a callback method used on context unloading from istack. type ContextUnloadCallback func(commit bool) error @@ -74,7 +83,9 @@ func NewContext(b []byte) *Context { // return value count and initial position in script. func NewContextWithParams(b []byte, rvcount int, pos int) *Context { return &Context{ - prog: b, + sc: &scriptContext{ + prog: b, + }, retCount: rvcount, nextip: pos, } @@ -82,7 +93,7 @@ func NewContextWithParams(b []byte, rvcount int, pos int) *Context { // Estack returns the evaluation stack of c. func (c *Context) Estack() *Stack { - return c.estack + return c.sc.estack } // NextIP returns the next instruction pointer. @@ -92,7 +103,7 @@ func (c *Context) NextIP() int { // Jump unconditionally moves the next instruction pointer to the specified location. func (c *Context) Jump(pos int) { - if pos < 0 || pos >= len(c.prog) { + if pos < 0 || pos >= len(c.sc.prog) { panic("instruction offset is out of range") } c.nextip = pos @@ -105,11 +116,12 @@ func (c *Context) Next() (opcode.Opcode, []byte, error) { var err error c.ip = c.nextip - if c.ip >= len(c.prog) { + prog := c.sc.prog + if c.ip >= len(prog) { return opcode.RET, nil, nil } - var instrbyte = c.prog[c.ip] + var instrbyte = prog[c.ip] instr := opcode.Opcode(instrbyte) if !opcode.IsValid(instr) { return instr, nil, fmt.Errorf("incorrect opcode %s", instr.String()) @@ -119,24 +131,24 @@ func (c *Context) Next() (opcode.Opcode, []byte, error) { var numtoread int switch instr { case opcode.PUSHDATA1: - if c.nextip >= len(c.prog) { + if c.nextip >= len(prog) { err = errNoInstParam } else { - numtoread = int(c.prog[c.nextip]) + numtoread = int(prog[c.nextip]) c.nextip++ } case opcode.PUSHDATA2: - if c.nextip+1 >= len(c.prog) { + if c.nextip+1 >= len(prog) { err = errNoInstParam } else { - numtoread = int(binary.LittleEndian.Uint16(c.prog[c.nextip : c.nextip+2])) + numtoread = int(binary.LittleEndian.Uint16(prog[c.nextip : c.nextip+2])) c.nextip += 2 } case opcode.PUSHDATA4: - if c.nextip+3 >= len(c.prog) { + if c.nextip+3 >= len(prog) { err = errNoInstParam } else { - var n = binary.LittleEndian.Uint32(c.prog[c.nextip : c.nextip+4]) + var n = binary.LittleEndian.Uint32(prog[c.nextip : c.nextip+4]) if n > stackitem.MaxSize { return instr, nil, errors.New("parameter is too big") } @@ -166,13 +178,13 @@ func (c *Context) Next() (opcode.Opcode, []byte, error) { return instr, nil, nil } } - if c.nextip+numtoread-1 >= len(c.prog) { + if c.nextip+numtoread-1 >= len(prog) { err = errNoInstParam } if err != nil { return instr, nil, err } - parameter := c.prog[c.nextip : c.nextip+numtoread] + parameter := prog[c.nextip : c.nextip+numtoread] c.nextip += numtoread return instr, parameter, nil } @@ -184,46 +196,44 @@ func (c *Context) IP() int { // LenInstr returns the number of instructions loaded. func (c *Context) LenInstr() int { - return len(c.prog) + return len(c.sc.prog) } // CurrInstr returns the current instruction and opcode. func (c *Context) CurrInstr() (int, opcode.Opcode) { - return c.ip, opcode.Opcode(c.prog[c.ip]) + return c.ip, opcode.Opcode(c.sc.prog[c.ip]) } // NextInstr returns the next instruction and opcode. func (c *Context) NextInstr() (int, opcode.Opcode) { op := opcode.RET - if c.nextip < len(c.prog) { - op = opcode.Opcode(c.prog[c.nextip]) + if c.nextip < len(c.sc.prog) { + op = opcode.Opcode(c.sc.prog[c.nextip]) } return c.nextip, op } -// Copy returns an new exact copy of c. -func (c *Context) Copy() *Context { - ctx := new(Context) - *ctx = *c - return ctx -} - // GetCallFlags returns the calling flags which the context was created with. func (c *Context) GetCallFlags() callflag.CallFlag { - return c.callFlag + return c.sc.callFlag } // Program returns the loaded program. func (c *Context) Program() []byte { - return c.prog + return c.sc.prog } // ScriptHash returns a hash of the script in the current context. func (c *Context) ScriptHash() util.Uint160 { - if c.scriptHash.Equals(util.Uint160{}) { - c.scriptHash = hash.Hash160(c.prog) + if c.sc.scriptHash.Equals(util.Uint160{}) { + c.sc.scriptHash = hash.Hash160(c.sc.prog) } - return c.scriptHash + return c.sc.scriptHash +} + +// GetNEF returns NEF structure used by this context if it's present. +func (c *Context) GetNEF() *nef.File { + return c.sc.NEF } // Value implements the stackitem.Item interface. @@ -263,7 +273,7 @@ func (c *Context) Equals(s stackitem.Item) bool { } func (c *Context) atBreakPoint() bool { - for _, n := range c.breakPoints { + for _, n := range c.sc.breakPoints { if n == c.nextip { return true } @@ -277,12 +287,12 @@ func (c *Context) String() string { // IsDeployed returns whether this context contains a deployed contract. func (c *Context) IsDeployed() bool { - return c.NEF != nil + return c.sc.NEF != nil } // DumpStaticSlot returns json formatted representation of the given slot. func (c *Context) DumpStaticSlot() string { - return dumpSlot(c.static) + return dumpSlot(&c.sc.static) } // DumpLocalSlot returns json formatted representation of the given slot. diff --git a/pkg/vm/interop.go b/pkg/vm/interop.go index 4885f255e..27156b31f 100644 --- a/pkg/vm/interop.go +++ b/pkg/vm/interop.go @@ -36,8 +36,9 @@ func defaultSyscallHandler(v *VM, id uint32) error { return errors.New("syscall not found") } d := defaultVMInterops[n] - if !v.Context().callFlag.Has(d.RequiredFlags) { - return fmt.Errorf("missing call flags: %05b vs %05b", v.Context().callFlag, d.RequiredFlags) + ctxFlag := v.Context().sc.callFlag + if !ctxFlag.Has(d.RequiredFlags) { + return fmt.Errorf("missing call flags: %05b vs %05b", ctxFlag, d.RequiredFlags) } return d.Func(v) } diff --git a/pkg/vm/json_test.go b/pkg/vm/json_test.go index 86a68540d..8256a6840 100644 --- a/pkg/vm/json_test.go +++ b/pkg/vm/json_test.go @@ -115,7 +115,7 @@ func testSyscallHandler(v *VM, id uint32) error { case 0x77777777: v.Estack().PushVal(stackitem.NewInterop(new(int))) case 0x66666666: - if !v.Context().callFlag.Has(callflag.ReadOnly) { + if !v.Context().sc.callFlag.Has(callflag.ReadOnly) { return errors.New("invalid call flags") } v.Estack().PushVal(stackitem.NewInterop(new(int))) @@ -167,14 +167,14 @@ func testFile(t *testing.T, filename string) { if len(result.InvocationStack) > 0 { for i, s := range result.InvocationStack { ctx := vm.istack.Peek(i).Value().(*Context) - if ctx.nextip < len(ctx.prog) { + if ctx.nextip < len(ctx.sc.prog) { require.Equal(t, s.InstructionPointer, ctx.nextip) op, err := opcode.FromString(s.Instruction) require.NoError(t, err) - require.Equal(t, op, opcode.Opcode(ctx.prog[ctx.nextip])) + require.Equal(t, op, opcode.Opcode(ctx.sc.prog[ctx.nextip])) } compareStacks(t, s.EStack, vm.estack) - compareSlots(t, s.StaticFields, ctx.static) + compareSlots(t, s.StaticFields, ctx.sc.static) } } @@ -240,8 +240,8 @@ func compareStacks(t *testing.T, expected []vmUTStackItem, actual *Stack) { compareItemArrays(t, expected, actual.Len(), func(i int) stackitem.Item { return actual.Peek(i).Item() }) } -func compareSlots(t *testing.T, expected []vmUTStackItem, actual *slot) { - if (actual == nil || *actual == nil) && len(expected) == 0 { +func compareSlots(t *testing.T, expected []vmUTStackItem, actual slot) { + if actual == nil && len(expected) == 0 { return } require.NotNil(t, actual) diff --git a/pkg/vm/opcodebench_test.go b/pkg/vm/opcodebench_test.go index 3ff86c09a..1494dbd9a 100644 --- a/pkg/vm/opcodebench_test.go +++ b/pkg/vm/opcodebench_test.go @@ -56,7 +56,7 @@ func opParamSlotsPushVM(op opcode.Opcode, param []byte, sslot int, slotloc int, return nil } if sslot != 0 { - v.Context().static.init(sslot, &v.refs) + v.Context().sc.static.init(sslot, &v.refs) } if slotloc != 0 && slotarg != 0 { v.Context().local.init(slotloc, &v.refs) diff --git a/pkg/vm/vm.go b/pkg/vm/vm.go index aea077e74..47c68a772 100644 --- a/pkg/vm/vm.go +++ b/pkg/vm/vm.go @@ -171,9 +171,7 @@ func (v *VM) PrintOps(out io.Writer) { w := tabwriter.NewWriter(out, 0, 0, 4, ' ', 0) fmt.Fprintln(w, "INDEX\tOPCODE\tPARAMETER") realctx := v.Context() - ctx := realctx.Copy() - ctx.ip = 0 - ctx.nextip = 0 + ctx := &Context{sc: realctx.sc} for { cursor := "" instr, parameter, err := ctx.Next() @@ -228,7 +226,7 @@ func (v *VM) PrintOps(out io.Writer) { } fmt.Fprintf(w, "%d\t%s\t%s%s\n", ctx.ip, instr, desc, cursor) - if ctx.nextip >= len(ctx.prog) { + if ctx.nextip >= len(ctx.sc.prog) { break } } @@ -246,7 +244,7 @@ func getOffsetDesc(ctx *Context, parameter []byte) string { // AddBreakPoint adds a breakpoint to the current context. func (v *VM) AddBreakPoint(n int) { ctx := v.Context() - ctx.breakPoints = append(ctx.breakPoints, n) + ctx.sc.breakPoints = append(ctx.sc.breakPoints, n) } // AddBreakPointRel adds a breakpoint relative to the current @@ -337,31 +335,28 @@ func (v *VM) LoadNEFMethod(exe *nef.File, caller util.Uint160, hash util.Uint160 // It should be used for calling from native contracts. func (v *VM) loadScriptWithCallingHash(b []byte, exe *nef.File, caller util.Uint160, hash util.Uint160, f callflag.CallFlag, rvcount int, offset int, onContextUnload ContextUnloadCallback) { - var sl slot - v.checkInvocationStackSize() ctx := NewContextWithParams(b, rvcount, offset) if rvcount != -1 || v.estack.Len() != 0 { v.estack = subStack(v.estack) } - ctx.estack = v.estack + ctx.sc.estack = v.estack initStack(&ctx.tryStack, "exception", nil) - ctx.callFlag = f - ctx.static = &sl - ctx.scriptHash = hash - ctx.callingScriptHash = caller - ctx.NEF = exe + ctx.sc.callFlag = f + ctx.sc.scriptHash = hash + ctx.sc.callingScriptHash = caller + ctx.sc.NEF = exe if v.invTree != nil { curTree := v.invTree parent := v.Context() if parent != nil { - curTree = parent.invTree + curTree = parent.sc.invTree } newTree := &invocations.Tree{Current: ctx.ScriptHash()} curTree.Calls = append(curTree.Calls, newTree) - ctx.invTree = newTree + ctx.sc.invTree = newTree } - ctx.onUnload = onContextUnload + ctx.sc.onUnload = onContextUnload v.istack.PushItem(ctx) } @@ -481,7 +476,7 @@ func (v *VM) StepInto() error { return nil } - if ctx != nil && ctx.prog != nil { + if ctx != nil && ctx.sc.prog != nil { op, param, err := ctx.Next() if err != nil { v.state = vmstate.Fault @@ -584,7 +579,7 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro } }() - if v.getPrice != nil && ctx.ip < len(ctx.prog) { + if v.getPrice != nil && ctx.ip < len(ctx.sc.prog) { v.gasConsumed += v.getPrice(op, parameter) if v.GasLimit >= 0 && v.gasConsumed > v.GasLimit { panic("gas limit is exceeded") @@ -610,7 +605,7 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro case opcode.PUSHA: n := getJumpOffset(ctx, parameter) - ptr := stackitem.NewPointerWithHash(n, ctx.prog, ctx.ScriptHash()) + ptr := stackitem.NewPointerWithHash(n, ctx.sc.prog, ctx.ScriptHash()) v.estack.PushItem(ptr) case opcode.PUSHNULL: @@ -637,7 +632,7 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro if parameter[0] == 0 { panic("zero argument") } - ctx.static.init(int(parameter[0]), &v.refs) + ctx.sc.static.init(int(parameter[0]), &v.refs) case opcode.INITSLOT: if ctx.local != nil || ctx.arguments != nil { @@ -658,20 +653,20 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro } case opcode.LDSFLD0, opcode.LDSFLD1, opcode.LDSFLD2, opcode.LDSFLD3, opcode.LDSFLD4, opcode.LDSFLD5, opcode.LDSFLD6: - item := ctx.static.Get(int(op - opcode.LDSFLD0)) + item := ctx.sc.static.Get(int(op - opcode.LDSFLD0)) v.estack.PushItem(item) case opcode.LDSFLD: - item := ctx.static.Get(int(parameter[0])) + item := ctx.sc.static.Get(int(parameter[0])) v.estack.PushItem(item) case opcode.STSFLD0, opcode.STSFLD1, opcode.STSFLD2, opcode.STSFLD3, opcode.STSFLD4, opcode.STSFLD5, opcode.STSFLD6: item := v.estack.Pop().Item() - ctx.static.Set(int(op-opcode.STSFLD0), item, &v.refs) + ctx.sc.static.Set(int(op-opcode.STSFLD0), item, &v.refs) case opcode.STSFLD: item := v.estack.Pop().Item() - ctx.static.Set(int(parameter[0]), item, &v.refs) + ctx.sc.static.Set(int(parameter[0]), item, &v.refs) case opcode.LDLOC0, opcode.LDLOC1, opcode.LDLOC2, opcode.LDLOC3, opcode.LDLOC4, opcode.LDLOC5, opcode.LDLOC6: item := ctx.local.Get(int(op - opcode.LDLOC0)) @@ -1475,7 +1470,7 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro break } - newEstack := v.Context().estack + newEstack := v.Context().sc.estack if oldEstack != newEstack { if oldCtx.retCount >= 0 && oldEstack.Len() != oldCtx.retCount { panic(fmt.Errorf("invalid return values count: expected %d, got %d", @@ -1631,17 +1626,19 @@ func (v *VM) unloadContext(ctx *Context) { ctx.arguments.ClearRefs(&v.refs) } currCtx := v.Context() - if ctx.static != nil && (currCtx == nil || ctx.static != currCtx.static) { - ctx.static.ClearRefs(&v.refs) - } - if ctx.onUnload != nil { - err := ctx.onUnload(v.uncaughtException == nil) - if err != nil { - errMessage := fmt.Sprintf("context unload callback failed: %s", err) - if v.uncaughtException != nil { - errMessage = fmt.Sprintf("%s, uncaught exception: %s", errMessage, v.uncaughtException) + if currCtx == nil || ctx.sc != currCtx.sc { + if ctx.sc.static != nil { + ctx.sc.static.ClearRefs(&v.refs) + } + if ctx.sc.onUnload != nil { + err := ctx.sc.onUnload(v.uncaughtException == nil) + if err != nil { + errMessage := fmt.Sprintf("context unload callback failed: %s", err) + if v.uncaughtException != nil { + errMessage = fmt.Sprintf("%s, uncaught exception: %s", errMessage, v.uncaughtException) + } + panic(errors.New(errMessage)) } - panic(errors.New(errMessage)) } } } @@ -1691,17 +1688,13 @@ func (v *VM) Call(offset int) { // package. func (v *VM) call(ctx *Context, offset int) { v.checkInvocationStackSize() - newCtx := ctx.Copy() - newCtx.retCount = -1 - newCtx.local = nil - newCtx.arguments = nil - // If memory for `elems` is reused, we can end up - // with an incorrect exception context state in the caller. - newCtx.tryStack.elems = nil - initStack(&newCtx.tryStack, "exception", nil) - newCtx.NEF = ctx.NEF - // Do not clone unloading callback, new context does not require any actions to perform on unloading. - newCtx.onUnload = nil + newCtx := &Context{ + sc: ctx.sc, + retCount: -1, + tryStack: ctx.tryStack, + } + // New context -> new exception handlers. + newCtx.tryStack.elems = ctx.tryStack.elems[len(ctx.tryStack.elems):] v.istack.PushItem(newCtx) newCtx.Jump(offset) } @@ -1732,7 +1725,7 @@ func calcJumpOffset(ctx *Context, parameter []byte) (int, int, error) { return 0, 0, fmt.Errorf("invalid %s parameter length: %d", curr, l) } offset := ctx.ip + int(rOffset) - if offset < 0 || offset > len(ctx.prog) { + if offset < 0 || offset > len(ctx.sc.prog) { return 0, 0, fmt.Errorf("invalid offset %d ip at %d", offset, ctx.ip) } @@ -1955,7 +1948,7 @@ func bytesToPublicKey(b []byte, curve elliptic.Curve) *keys.PublicKey { // GetCallingScriptHash implements the ScriptHashGetter interface. func (v *VM) GetCallingScriptHash() util.Uint160 { - return v.Context().callingScriptHash + return v.Context().sc.callingScriptHash } // GetEntryScriptHash implements the ScriptHashGetter interface.