vm: completely separate instruction read and execution phases

Make Context.Next() return both the opcode and the instruction parameter, if
any. This simplifies some code and is needed to deal with #295.
This commit is contained in:
Roman Khimov 2019-10-03 16:54:14 +03:00
parent 1bf232ad50
commit 53a3b18652
3 changed files with 82 additions and 102 deletions

View file

@ -1,7 +1,9 @@
package vm package vm
import ( import (
"encoding/binary" "errors"
"github.com/CityOfZion/neo-go/pkg/io"
) )
// Context represent the current execution context of the VM. // Context represent the current execution context of the VM.
@ -9,6 +11,9 @@ type Context struct {
// Instruction pointer. // Instruction pointer.
ip int ip int
// The next instruction pointer.
nextip int
// The raw program script. // The raw program script.
prog []byte prog []byte
@ -19,19 +24,62 @@ type Context struct {
// NewContext return a new Context object. // NewContext return a new Context object.
func NewContext(b []byte) *Context { func NewContext(b []byte) *Context {
return &Context{ return &Context{
ip: -1,
prog: b, prog: b,
breakPoints: []int{}, breakPoints: []int{},
} }
} }
// Next return the next instruction to execute. // Next returns the next instruction to execute with its parameter if any. After
func (c *Context) Next() Instruction { // its invocation the instruction pointer points to the instruction being
c.ip++ // returned.
func (c *Context) Next() (Instruction, []byte, error) {
c.ip = c.nextip
if c.ip >= len(c.prog) { if c.ip >= len(c.prog) {
return RET return RET, nil, nil
} }
return Instruction(c.prog[c.ip]) r := io.NewBinReaderFromBuf(c.prog[c.ip:])
var instrbyte byte
r.ReadLE(&instrbyte)
instr := Instruction(instrbyte)
c.nextip++
var numtoread int
switch instr {
case PUSHDATA1, SYSCALL:
var n byte
r.ReadLE(&n)
numtoread = int(n)
c.nextip++
case PUSHDATA2:
var n uint16
r.ReadLE(&n)
numtoread = int(n)
c.nextip += 2
case PUSHDATA4:
var n uint32
r.ReadLE(&n)
numtoread = int(n)
c.nextip += 4
case JMP, JMPIF, JMPIFNOT, CALL:
numtoread = 2
case APPCALL, TAILCALL:
numtoread = 20
default:
if instr >= PUSHBYTES1 && instr <= PUSHBYTES75 {
numtoread = int(instr)
} else {
// No parameters, can just return.
return instr, nil, nil
}
}
parameter := make([]byte, numtoread)
r.ReadLE(parameter)
if r.Err != nil {
return instr, nil, errors.New("failed to read instruction parameter")
}
c.nextip += numtoread
return instr, parameter, nil
} }
// IP returns the absolute instruction without taking 0 into account. // IP returns the absolute instruction without taking 0 into account.
@ -48,19 +96,14 @@ func (c *Context) LenInstr() int {
// CurrInstr returns the current instruction and opcode. // CurrInstr returns the current instruction and opcode.
func (c *Context) CurrInstr() (int, Instruction) { func (c *Context) CurrInstr() (int, Instruction) {
if c.ip < 0 {
return c.ip, NOP
}
return c.ip, Instruction(c.prog[c.ip]) return c.ip, Instruction(c.prog[c.ip])
} }
// Copy returns an new exact copy of c. // Copy returns an new exact copy of c.
func (c *Context) Copy() *Context { func (c *Context) Copy() *Context {
return &Context{ ctx := new(Context)
ip: c.ip, *ctx = *c
prog: c.prog, return ctx
breakPoints: c.breakPoints,
}
} }
// Program returns the loaded program. // Program returns the loaded program.
@ -85,44 +128,3 @@ func (c *Context) atBreakPoint() bool {
func (c *Context) String() string { func (c *Context) String() string {
return "execution context" return "execution context"
} }
func (c *Context) readUint32() uint32 {
start, end := c.IP(), c.IP()+4
if end > len(c.prog) {
panic("failed to read uint32 parameter")
}
val := binary.LittleEndian.Uint32(c.prog[start:end])
c.ip += 4
return val
}
func (c *Context) readUint16() uint16 {
start, end := c.IP(), c.IP()+2
if end > len(c.prog) {
panic("failed to read uint16 parameter")
}
val := binary.LittleEndian.Uint16(c.prog[start:end])
c.ip += 2
return val
}
func (c *Context) readByte() byte {
return c.readBytes(1)[0]
}
func (c *Context) readBytes(n int) []byte {
start, end := c.IP(), c.IP()+n
if end > len(c.prog) {
return nil
}
out := make([]byte, n)
copy(out, c.prog[start:end])
c.ip += n
return out
}
func (c *Context) readVarBytes() []byte {
n := c.readByte()
return c.readBytes(int(n))
}

View file

@ -2,6 +2,7 @@ package vm
import ( import (
"crypto/sha1" "crypto/sha1"
"encoding/binary"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"log" "log"
@ -224,6 +225,11 @@ func (v *VM) Run() {
v.state = noneState v.state = noneState
for { for {
// check for breakpoint before executing the next instruction
ctx := v.Context()
if ctx != nil && ctx.atBreakPoint() {
v.state |= breakState
}
switch { switch {
case v.state.HasFlag(faultState): case v.state.HasFlag(faultState):
fmt.Println("FAULT") fmt.Println("FAULT")
@ -247,14 +253,13 @@ func (v *VM) Run() {
// Step 1 instruction in the program. // Step 1 instruction in the program.
func (v *VM) Step() { func (v *VM) Step() {
ctx := v.Context() ctx := v.Context()
op := ctx.Next() op, param, err := ctx.Next()
v.execute(ctx, op) if err != nil {
log.Printf("error encountered at instruction %d (%s)", ctx.ip, op)
// re-peek the context as it could been changed during execution. log.Println(err)
cctx := v.Context() v.state = faultState
if cctx != nil && cctx.atBreakPoint() {
v.state = breakState
} }
v.execute(ctx, op, param)
} }
// HasFailed returns whether VM is in the failed state now. Usually used to // HasFailed returns whether VM is in the failed state now. Usually used to
@ -275,7 +280,7 @@ func (v *VM) SetScriptGetter(gs func(util.Uint160) []byte) {
} }
// execute performs an instruction cycle in the VM. Acting on the instruction (opcode). // execute performs an instruction cycle in the VM. Acting on the instruction (opcode).
func (v *VM) execute(ctx *Context, op Instruction) { func (v *VM) execute(ctx *Context, op Instruction, parameter []byte) {
// Instead of polluting the whole VM logic with error handling, we will recover // Instead of polluting the whole VM logic with error handling, we will recover
// each panic at a central point, putting the VM in a fault state. // each panic at a central point, putting the VM in a fault state.
defer func() { defer func() {
@ -287,11 +292,7 @@ func (v *VM) execute(ctx *Context, op Instruction) {
}() }()
if op >= PUSHBYTES1 && op <= PUSHBYTES75 { if op >= PUSHBYTES1 && op <= PUSHBYTES75 {
b := ctx.readBytes(int(op)) v.estack.PushVal(parameter)
if b == nil {
panic("failed to read instruction parameter")
}
v.estack.PushVal(b)
return return
} }
@ -305,29 +306,8 @@ func (v *VM) execute(ctx *Context, op Instruction) {
case PUSH0: case PUSH0:
v.estack.PushVal([]byte{}) v.estack.PushVal([]byte{})
case PUSHDATA1: case PUSHDATA1, PUSHDATA2, PUSHDATA4:
n := ctx.readByte() v.estack.PushVal(parameter)
b := ctx.readBytes(int(n))
if b == nil {
panic("failed to read instruction parameter")
}
v.estack.PushVal(b)
case PUSHDATA2:
n := ctx.readUint16()
b := ctx.readBytes(int(n))
if b == nil {
panic("failed to read instruction parameter")
}
v.estack.PushVal(b)
case PUSHDATA4:
n := ctx.readUint32()
b := ctx.readBytes(int(n))
if b == nil {
panic("failed to read instruction parameter")
}
v.estack.PushVal(b)
// Stack operations. // Stack operations.
case TOALTSTACK: case TOALTSTACK:
@ -843,8 +823,8 @@ func (v *VM) execute(ctx *Context, op Instruction) {
case JMP, JMPIF, JMPIFNOT: case JMP, JMPIF, JMPIFNOT:
var ( var (
rOffset = int16(ctx.readUint16()) rOffset = int16(binary.LittleEndian.Uint16(parameter))
offset = ctx.ip + int(rOffset) - 3 // sizeOf(int16 + uint8) offset = ctx.ip + int(rOffset)
) )
if offset < 0 || offset > len(ctx.prog) { if offset < 0 || offset > len(ctx.prog) {
panic(fmt.Sprintf("JMP: invalid offset %d ip at %d", offset, ctx.ip)) panic(fmt.Sprintf("JMP: invalid offset %d ip at %d", offset, ctx.ip))
@ -857,19 +837,17 @@ func (v *VM) execute(ctx *Context, op Instruction) {
} }
} }
if cond { if cond {
ctx.ip = offset ctx.nextip = offset
} }
case CALL: case CALL:
v.istack.PushVal(ctx.Copy()) v.istack.PushVal(ctx.Copy())
ctx.ip += 2 v.execute(v.Context(), JMP, parameter)
v.execute(v.Context(), JMP)
case SYSCALL: case SYSCALL:
api := ctx.readVarBytes() ifunc, ok := v.interop[string(parameter)]
ifunc, ok := v.interop[string(api)]
if !ok { if !ok {
panic(fmt.Sprintf("interop hook (%s) not registered", api)) panic(fmt.Sprintf("interop hook (%q) not registered", parameter))
} }
if err := ifunc.Func(v); err != nil { if err := ifunc.Func(v); err != nil {
panic(fmt.Sprintf("failed to invoke syscall: %s", err)) panic(fmt.Sprintf("failed to invoke syscall: %s", err))
@ -880,7 +858,7 @@ func (v *VM) execute(ctx *Context, op Instruction) {
panic("no getScript callback is set up") panic("no getScript callback is set up")
} }
hash, err := util.Uint160DecodeBytes(ctx.readBytes(20)) hash, err := util.Uint160DecodeBytes(parameter)
if err != nil { if err != nil {
panic(err) panic(err)
} }

View file

@ -54,7 +54,7 @@ func TestPushBytes1to75(t *testing.T) {
assert.IsType(t, elem.Bytes(), b) assert.IsType(t, elem.Bytes(), b)
assert.Equal(t, 0, vm.estack.Len()) assert.Equal(t, 0, vm.estack.Len())
vm.execute(nil, RET) vm.execute(nil, RET, nil)
assert.Equal(t, 0, vm.astack.Len()) assert.Equal(t, 0, vm.astack.Len())
assert.Equal(t, 0, vm.istack.Len()) assert.Equal(t, 0, vm.istack.Len())