diff --git a/pkg/vm/context.go b/pkg/vm/context.go index bf4f061ee..1e7ac9d1a 100644 --- a/pkg/vm/context.go +++ b/pkg/vm/context.go @@ -1,7 +1,9 @@ package vm import ( - "encoding/binary" + "errors" + + "github.com/CityOfZion/neo-go/pkg/io" ) // Context represent the current execution context of the VM. @@ -9,6 +11,9 @@ type Context struct { // Instruction pointer. ip int + // The next instruction pointer. + nextip int + // The raw program script. prog []byte @@ -19,19 +24,62 @@ type Context struct { // NewContext return a new Context object. func NewContext(b []byte) *Context { return &Context{ - ip: -1, prog: b, breakPoints: []int{}, } } -// Next return the next instruction to execute. -func (c *Context) Next() Instruction { - c.ip++ +// Next returns the next instruction to execute with its parameter if any. After +// its invocation the instruction pointer points to the instruction being +// returned. +func (c *Context) Next() (Instruction, []byte, error) { + c.ip = c.nextip if c.ip >= len(c.prog) { - return RET + return RET, nil, nil } - return Instruction(c.prog[c.ip]) + r := io.NewBinReaderFromBuf(c.prog[c.ip:]) + + var instrbyte byte + r.ReadLE(&instrbyte) + instr := Instruction(instrbyte) + c.nextip++ + + var numtoread int + switch instr { + case PUSHDATA1, SYSCALL: + var n byte + r.ReadLE(&n) + numtoread = int(n) + c.nextip++ + case PUSHDATA2: + var n uint16 + r.ReadLE(&n) + numtoread = int(n) + c.nextip += 2 + case PUSHDATA4: + var n uint32 + r.ReadLE(&n) + numtoread = int(n) + c.nextip += 4 + case JMP, JMPIF, JMPIFNOT, CALL: + numtoread = 2 + case APPCALL, TAILCALL: + numtoread = 20 + default: + if instr >= PUSHBYTES1 && instr <= PUSHBYTES75 { + numtoread = int(instr) + } else { + // No parameters, can just return. + return instr, nil, nil + } + } + parameter := make([]byte, numtoread) + r.ReadLE(parameter) + if r.Err != nil { + return instr, nil, errors.New("failed to read instruction parameter") + } + c.nextip += numtoread + return instr, parameter, nil } // IP returns the absolute instruction without taking 0 into account. @@ -48,19 +96,14 @@ func (c *Context) LenInstr() int { // CurrInstr returns the current instruction and opcode. func (c *Context) CurrInstr() (int, Instruction) { - if c.ip < 0 { - return c.ip, NOP - } return c.ip, Instruction(c.prog[c.ip]) } // Copy returns an new exact copy of c. func (c *Context) Copy() *Context { - return &Context{ - ip: c.ip, - prog: c.prog, - breakPoints: c.breakPoints, - } + ctx := new(Context) + *ctx = *c + return ctx } // Program returns the loaded program. @@ -85,44 +128,3 @@ func (c *Context) atBreakPoint() bool { func (c *Context) String() string { return "execution context" } - -func (c *Context) readUint32() uint32 { - start, end := c.IP(), c.IP()+4 - if end > len(c.prog) { - panic("failed to read uint32 parameter") - } - val := binary.LittleEndian.Uint32(c.prog[start:end]) - c.ip += 4 - return val -} - -func (c *Context) readUint16() uint16 { - start, end := c.IP(), c.IP()+2 - if end > len(c.prog) { - panic("failed to read uint16 parameter") - } - val := binary.LittleEndian.Uint16(c.prog[start:end]) - c.ip += 2 - return val -} - -func (c *Context) readByte() byte { - return c.readBytes(1)[0] -} - -func (c *Context) readBytes(n int) []byte { - start, end := c.IP(), c.IP()+n - if end > len(c.prog) { - return nil - } - - out := make([]byte, n) - copy(out, c.prog[start:end]) - c.ip += n - return out -} - -func (c *Context) readVarBytes() []byte { - n := c.readByte() - return c.readBytes(int(n)) -} diff --git a/pkg/vm/vm.go b/pkg/vm/vm.go index 371688b82..2a698febb 100644 --- a/pkg/vm/vm.go +++ b/pkg/vm/vm.go @@ -2,6 +2,7 @@ package vm import ( "crypto/sha1" + "encoding/binary" "fmt" "io/ioutil" "log" @@ -224,6 +225,11 @@ func (v *VM) Run() { v.state = noneState for { + // check for breakpoint before executing the next instruction + ctx := v.Context() + if ctx != nil && ctx.atBreakPoint() { + v.state |= breakState + } switch { case v.state.HasFlag(faultState): fmt.Println("FAULT") @@ -247,14 +253,13 @@ func (v *VM) Run() { // Step 1 instruction in the program. func (v *VM) Step() { ctx := v.Context() - op := ctx.Next() - v.execute(ctx, op) - - // re-peek the context as it could been changed during execution. - cctx := v.Context() - if cctx != nil && cctx.atBreakPoint() { - v.state = breakState + op, param, err := ctx.Next() + if err != nil { + log.Printf("error encountered at instruction %d (%s)", ctx.ip, op) + log.Println(err) + v.state = faultState } + v.execute(ctx, op, param) } // HasFailed returns whether VM is in the failed state now. Usually used to @@ -275,7 +280,7 @@ func (v *VM) SetScriptGetter(gs func(util.Uint160) []byte) { } // execute performs an instruction cycle in the VM. Acting on the instruction (opcode). -func (v *VM) execute(ctx *Context, op Instruction) { +func (v *VM) execute(ctx *Context, op Instruction, parameter []byte) { // Instead of polluting the whole VM logic with error handling, we will recover // each panic at a central point, putting the VM in a fault state. defer func() { @@ -287,11 +292,7 @@ func (v *VM) execute(ctx *Context, op Instruction) { }() if op >= PUSHBYTES1 && op <= PUSHBYTES75 { - b := ctx.readBytes(int(op)) - if b == nil { - panic("failed to read instruction parameter") - } - v.estack.PushVal(b) + v.estack.PushVal(parameter) return } @@ -305,29 +306,8 @@ func (v *VM) execute(ctx *Context, op Instruction) { case PUSH0: v.estack.PushVal([]byte{}) - case PUSHDATA1: - n := ctx.readByte() - b := ctx.readBytes(int(n)) - if b == nil { - panic("failed to read instruction parameter") - } - v.estack.PushVal(b) - - case PUSHDATA2: - n := ctx.readUint16() - b := ctx.readBytes(int(n)) - if b == nil { - panic("failed to read instruction parameter") - } - v.estack.PushVal(b) - - case PUSHDATA4: - n := ctx.readUint32() - b := ctx.readBytes(int(n)) - if b == nil { - panic("failed to read instruction parameter") - } - v.estack.PushVal(b) + case PUSHDATA1, PUSHDATA2, PUSHDATA4: + v.estack.PushVal(parameter) // Stack operations. case TOALTSTACK: @@ -843,8 +823,8 @@ func (v *VM) execute(ctx *Context, op Instruction) { case JMP, JMPIF, JMPIFNOT: var ( - rOffset = int16(ctx.readUint16()) - offset = ctx.ip + int(rOffset) - 3 // sizeOf(int16 + uint8) + rOffset = int16(binary.LittleEndian.Uint16(parameter)) + offset = ctx.ip + int(rOffset) ) if offset < 0 || offset > len(ctx.prog) { panic(fmt.Sprintf("JMP: invalid offset %d ip at %d", offset, ctx.ip)) @@ -857,19 +837,17 @@ func (v *VM) execute(ctx *Context, op Instruction) { } } if cond { - ctx.ip = offset + ctx.nextip = offset } case CALL: v.istack.PushVal(ctx.Copy()) - ctx.ip += 2 - v.execute(v.Context(), JMP) + v.execute(v.Context(), JMP, parameter) case SYSCALL: - api := ctx.readVarBytes() - ifunc, ok := v.interop[string(api)] + ifunc, ok := v.interop[string(parameter)] if !ok { - panic(fmt.Sprintf("interop hook (%s) not registered", api)) + panic(fmt.Sprintf("interop hook (%q) not registered", parameter)) } if err := ifunc.Func(v); err != nil { panic(fmt.Sprintf("failed to invoke syscall: %s", err)) @@ -880,7 +858,7 @@ func (v *VM) execute(ctx *Context, op Instruction) { panic("no getScript callback is set up") } - hash, err := util.Uint160DecodeBytes(ctx.readBytes(20)) + hash, err := util.Uint160DecodeBytes(parameter) if err != nil { panic(err) } diff --git a/pkg/vm/vm_test.go b/pkg/vm/vm_test.go index 7e11eb7c5..e14f08dad 100644 --- a/pkg/vm/vm_test.go +++ b/pkg/vm/vm_test.go @@ -54,7 +54,7 @@ func TestPushBytes1to75(t *testing.T) { assert.IsType(t, elem.Bytes(), b) assert.Equal(t, 0, vm.estack.Len()) - vm.execute(nil, RET) + vm.execute(nil, RET, nil) assert.Equal(t, 0, vm.astack.Len()) assert.Equal(t, 0, vm.istack.Len())