Merge pull request #417 from nspcc-dev/various-verification-fixes2

Various verification fixes part 2, which focuses on VM improvements
necessary to make tx verification work. It's mostly related to interop
functionality, but doesn't add interops at the moment. Fixes #295 along the way.
This commit is contained in:
Roman Khimov 2019-10-04 16:17:24 +03:00 committed by GitHub
commit aab2f9a837
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 436 additions and 183 deletions

View file

@ -1,6 +1,10 @@
package vm package vm
import ( import (
"errors"
"io/ioutil"
"github.com/CityOfZion/neo-go/pkg/vm"
vmcli "github.com/CityOfZion/neo-go/pkg/vm/cli" vmcli "github.com/CityOfZion/neo-go/pkg/vm/cli"
"github.com/urfave/cli" "github.com/urfave/cli"
) )
@ -14,6 +18,19 @@ func NewCommand() cli.Command {
Flags: []cli.Flag{ Flags: []cli.Flag{
cli.BoolFlag{Name: "debug, d"}, cli.BoolFlag{Name: "debug, d"},
}, },
Subcommands: []cli.Command{
{
Name: "inspect",
Usage: "dump instructions of the avm file given",
Action: inspect,
Flags: []cli.Flag{
cli.StringFlag{
Name: "in, i",
Usage: "input file of the program (AVM)",
},
},
},
},
} }
} }
@ -21,3 +38,18 @@ func startVMPrompt(ctx *cli.Context) error {
p := vmcli.New() p := vmcli.New()
return p.Run() return p.Run()
} }
func inspect(ctx *cli.Context) error {
avm := ctx.String("in")
if len(avm) == 0 {
return cli.NewExitError(errors.New("no input file given"), 1)
}
b, err := ioutil.ReadFile(avm)
if err != nil {
return cli.NewExitError(err, 1)
}
v := vm.New(0)
v.LoadScript(b)
v.PrintOps()
return nil
}

View file

@ -313,6 +313,7 @@ func (bc *Blockchain) storeBlock(block *Block) error {
spentCoins = make(SpentCoins) spentCoins = make(SpentCoins)
accounts = make(Accounts) accounts = make(Accounts)
assets = make(Assets) assets = make(Assets)
contracts = make(Contracts)
) )
if err := storeAsBlock(batch, block, 0); err != nil { if err := storeAsBlock(batch, block, 0); err != nil {
@ -399,7 +400,7 @@ func (bc *Blockchain) storeBlock(block *Block) error {
Email: t.Email, Email: t.Email,
Description: t.Description, Description: t.Description,
} }
_ = contract contracts[contract.ScriptHash()] = contract
case *transaction.InvocationTX: case *transaction.InvocationTX:
} }
@ -418,6 +419,9 @@ func (bc *Blockchain) storeBlock(block *Block) error {
if err := assets.commit(batch); err != nil { if err := assets.commit(batch); err != nil {
return err return err
} }
if err := contracts.commit(batch); err != nil {
return err
}
if err := bc.memStore.PutBatch(batch); err != nil { if err := bc.memStore.PutBatch(batch); err != nil {
return err return err
} }
@ -643,6 +647,33 @@ func getAssetStateFromStore(s storage.Store, assetID util.Uint256) *AssetState {
return &a return &a
} }
// GetContractState returns contract by its script hash.
func (bc *Blockchain) GetContractState(hash util.Uint160) *ContractState {
cs := getContractStateFromStore(bc.memStore, hash)
if cs == nil {
cs = getContractStateFromStore(bc.Store, hash)
}
return cs
}
// getContractStateFromStore returns contract state as recorded in the given
// store by the given script hash.
func getContractStateFromStore(s storage.Store, hash util.Uint160) *ContractState {
key := storage.AppendPrefix(storage.STContract, hash.Bytes())
contractBytes, err := s.Get(key)
if err != nil {
return nil
}
var c ContractState
r := io.NewBinReaderFromBuf(contractBytes)
c.DecodeBinary(r)
if r.Err != nil || c.ScriptHash() != hash {
return nil
}
return &c
}
// GetAccountState returns the account state from its script hash // GetAccountState returns the account state from its script hash
func (bc *Blockchain) GetAccountState(scriptHash util.Uint160) *AccountState { func (bc *Blockchain) GetAccountState(scriptHash util.Uint160) *AccountState {
as, err := getAccountStateFromStore(bc.memStore, scriptHash) as, err := getAccountStateFromStore(bc.memStore, scriptHash)
@ -1001,20 +1032,30 @@ func (bc *Blockchain) VerifyWitnesses(t *transaction.Transaction) error {
vm := vm.New(vm.ModeMute) vm := vm.New(vm.ModeMute)
vm.SetCheckedHash(t.VerificationHash().Bytes()) vm.SetCheckedHash(t.VerificationHash().Bytes())
vm.SetScriptGetter(func(hash util.Uint160) []byte {
cs := bc.GetContractState(hash)
if cs == nil {
return nil
}
return cs.Script
})
vm.LoadScript(verification) vm.LoadScript(verification)
vm.LoadScript(witnesses[i].InvocationScript) vm.LoadScript(witnesses[i].InvocationScript)
vm.Run() vm.Run()
if vm.HasFailed() { if vm.HasFailed() {
return errors.Errorf("vm failed to execute the script") return errors.Errorf("vm failed to execute the script")
} }
res := vm.PopResult() resEl := vm.Estack().Pop()
switch res.(type) { if resEl != nil {
case bool: res, err := resEl.TryBool()
if !(res.(bool)) { if err != nil {
return err
}
if !res {
return errors.Errorf("signature check failed") return errors.Errorf("signature check failed")
} }
default: } else {
return errors.Errorf("vm returned non-boolean result") return errors.Errorf("no result returned from the script")
} }
} }

View file

@ -1,16 +1,22 @@
package core package core
import ( import (
"github.com/CityOfZion/neo-go/pkg/core/storage"
"github.com/CityOfZion/neo-go/pkg/crypto/hash"
"github.com/CityOfZion/neo-go/pkg/io"
"github.com/CityOfZion/neo-go/pkg/smartcontract" "github.com/CityOfZion/neo-go/pkg/smartcontract"
"github.com/CityOfZion/neo-go/pkg/util" "github.com/CityOfZion/neo-go/pkg/util"
) )
// Contracts is a mapping between scripthash and ContractState.
type Contracts map[util.Uint160]*ContractState
// ContractState holds information about a smart contract in the NEO blockchain. // ContractState holds information about a smart contract in the NEO blockchain.
type ContractState struct { type ContractState struct {
Script []byte Script []byte
ParamList []smartcontract.ParamType ParamList []smartcontract.ParamType
ReturnType smartcontract.ParamType ReturnType smartcontract.ParamType
Properties []int Properties []byte
Name string Name string
CodeVersion string CodeVersion string
Author string Author string
@ -21,3 +27,69 @@ type ContractState struct {
scriptHash util.Uint160 scriptHash util.Uint160
} }
// commit flushes all contracts to the given storage.Batch.
func (a Contracts) commit(b storage.Batch) error {
buf := io.NewBufBinWriter()
for hash, contract := range a {
contract.EncodeBinary(buf.BinWriter)
if buf.Err != nil {
return buf.Err
}
key := storage.AppendPrefix(storage.STContract, hash.Bytes())
b.Put(key, buf.Bytes())
buf.Reset()
}
return nil
}
// DecodeBinary implements Serializable interface.
func (a *ContractState) DecodeBinary(br *io.BinReader) {
a.Script = br.ReadBytes()
paramBytes := br.ReadBytes()
a.ParamList = make([]smartcontract.ParamType, len(paramBytes))
for k := range paramBytes {
a.ParamList[k] = smartcontract.ParamType(paramBytes[k])
}
br.ReadLE(&a.ReturnType)
a.Properties = br.ReadBytes()
a.Name = br.ReadString()
a.CodeVersion = br.ReadString()
a.Author = br.ReadString()
a.Email = br.ReadString()
a.Description = br.ReadString()
br.ReadLE(&a.HasStorage)
br.ReadLE(&a.HasDynamicInvoke)
a.createHash()
}
// EncodeBinary implements Serializable interface.
func (a *ContractState) EncodeBinary(bw *io.BinWriter) {
bw.WriteBytes(a.Script)
bw.WriteVarUint(uint64(len(a.ParamList)))
for k := range a.ParamList {
bw.WriteLE(a.ParamList[k])
}
bw.WriteLE(a.ReturnType)
bw.WriteBytes(a.Properties)
bw.WriteString(a.Name)
bw.WriteString(a.CodeVersion)
bw.WriteString(a.Author)
bw.WriteString(a.Email)
bw.WriteString(a.Description)
bw.WriteLE(a.HasStorage)
bw.WriteLE(a.HasDynamicInvoke)
}
// ScriptHash returns a contract script hash.
func (a *ContractState) ScriptHash() util.Uint160 {
if a.scriptHash.Equals(util.Uint160{}) {
a.createHash()
}
return a.scriptHash
}
// createHash creates contract script hash.
func (a *ContractState) createHash() {
a.scriptHash = hash.Hash160(a.Script)
}

View file

@ -0,0 +1,39 @@
package core
import (
"testing"
"github.com/CityOfZion/neo-go/pkg/crypto/hash"
"github.com/CityOfZion/neo-go/pkg/io"
"github.com/CityOfZion/neo-go/pkg/smartcontract"
"github.com/stretchr/testify/assert"
)
func TestEncodeDecodeContractState(t *testing.T) {
script := []byte("testscript")
contract := &ContractState{
Script: script,
ParamList: []smartcontract.ParamType{smartcontract.StringType, smartcontract.IntegerType, smartcontract.Hash160Type},
ReturnType: smartcontract.BoolType,
Properties: []byte("smth"),
Name: "Contracto",
CodeVersion: "1.0.0",
Author: "Joe Random",
Email: "joe@example.com",
Description: "Test contract",
HasStorage: true,
HasDynamicInvoke: false,
}
assert.Equal(t, hash.Hash160(script), contract.ScriptHash())
buf := io.NewBufBinWriter()
contract.EncodeBinary(buf.BinWriter)
assert.Nil(t, buf.Err)
contractDecoded := &ContractState{}
r := io.NewBinReaderFromBuf(buf.Bytes())
contractDecoded.DecodeBinary(r)
assert.Nil(t, r.Err)
assert.Equal(t, contract, contractDecoded)
assert.Equal(t, contract.ScriptHash(), contractDecoded.ScriptHash())
}

View file

@ -3,7 +3,7 @@ package smartcontract
import "github.com/CityOfZion/neo-go/pkg/util" import "github.com/CityOfZion/neo-go/pkg/util"
// ParamType represent the Type of the contract parameter // ParamType represent the Type of the contract parameter
type ParamType int type ParamType byte
// A list of supported smart contract parameter types. // A list of supported smart contract parameter types.
const ( const (

View file

@ -13,7 +13,6 @@ import (
"log" "log"
"os" "os"
"strings" "strings"
"text/tabwriter"
"github.com/CityOfZion/neo-go/pkg/vm" "github.com/CityOfZion/neo-go/pkg/vm"
"golang.org/x/tools/go/loader" "golang.org/x/tools/go/loader"
@ -108,25 +107,9 @@ func CompileAndInspect(src string) error {
return err return err
} }
w := tabwriter.NewWriter(os.Stdout, 0, 0, 4, ' ', 0) v := vm.New(0)
fmt.Fprintln(w, "INDEX\tOPCODE\tDESC\t") v.LoadScript(b)
for i := 0; i <= len(b)-1; { v.PrintOps()
instr := vm.Instruction(b[i])
paramlength := 0
fmt.Fprintf(w, "%d\t0x%x\t%s\t\n", i, b[i], instr)
i++
if instr >= vm.PUSHBYTES1 && instr <= vm.PUSHBYTES75 {
paramlength = int(instr)
}
if instr == vm.JMP || instr == vm.JMPIF || instr == vm.JMPIFNOT || instr == vm.CALL {
paramlength = 2
}
for x := 0; x < paramlength; x++ {
fmt.Fprintf(w, "%d\t0x%x\t%s\t\n", i, b[i+1+x], string(b[i+1+x]))
}
i += paramlength
}
w.Flush()
return nil return nil
} }

View file

@ -1,7 +1,9 @@
package vm package vm
import ( import (
"encoding/binary" "errors"
"github.com/CityOfZion/neo-go/pkg/io"
) )
// Context represent the current execution context of the VM. // Context represent the current execution context of the VM.
@ -9,6 +11,9 @@ type Context struct {
// Instruction pointer. // Instruction pointer.
ip int ip int
// The next instruction pointer.
nextip int
// The raw program script. // The raw program script.
prog []byte prog []byte
@ -19,19 +24,62 @@ type Context struct {
// NewContext return a new Context object. // NewContext return a new Context object.
func NewContext(b []byte) *Context { func NewContext(b []byte) *Context {
return &Context{ return &Context{
ip: -1,
prog: b, prog: b,
breakPoints: []int{}, breakPoints: []int{},
} }
} }
// Next return the next instruction to execute. // Next returns the next instruction to execute with its parameter if any. After
func (c *Context) Next() Instruction { // its invocation the instruction pointer points to the instruction being
c.ip++ // returned.
func (c *Context) Next() (Instruction, []byte, error) {
c.ip = c.nextip
if c.ip >= len(c.prog) { if c.ip >= len(c.prog) {
return RET return RET, nil, nil
} }
return Instruction(c.prog[c.ip]) r := io.NewBinReaderFromBuf(c.prog[c.ip:])
var instrbyte byte
r.ReadLE(&instrbyte)
instr := Instruction(instrbyte)
c.nextip++
var numtoread int
switch instr {
case PUSHDATA1, SYSCALL:
var n byte
r.ReadLE(&n)
numtoread = int(n)
c.nextip++
case PUSHDATA2:
var n uint16
r.ReadLE(&n)
numtoread = int(n)
c.nextip += 2
case PUSHDATA4:
var n uint32
r.ReadLE(&n)
numtoread = int(n)
c.nextip += 4
case JMP, JMPIF, JMPIFNOT, CALL:
numtoread = 2
case APPCALL, TAILCALL:
numtoread = 20
default:
if instr >= PUSHBYTES1 && instr <= PUSHBYTES75 {
numtoread = int(instr)
} else {
// No parameters, can just return.
return instr, nil, nil
}
}
parameter := make([]byte, numtoread)
r.ReadLE(parameter)
if r.Err != nil {
return instr, nil, errors.New("failed to read instruction parameter")
}
c.nextip += numtoread
return instr, parameter, nil
} }
// IP returns the absolute instruction without taking 0 into account. // IP returns the absolute instruction without taking 0 into account.
@ -48,19 +96,14 @@ func (c *Context) LenInstr() int {
// CurrInstr returns the current instruction and opcode. // CurrInstr returns the current instruction and opcode.
func (c *Context) CurrInstr() (int, Instruction) { func (c *Context) CurrInstr() (int, Instruction) {
if c.ip < 0 {
return c.ip, NOP
}
return c.ip, Instruction(c.prog[c.ip]) return c.ip, Instruction(c.prog[c.ip])
} }
// Copy returns an new exact copy of c. // Copy returns an new exact copy of c.
func (c *Context) Copy() *Context { func (c *Context) Copy() *Context {
return &Context{ ctx := new(Context)
ip: c.ip, *ctx = *c
prog: c.prog, return ctx
breakPoints: c.breakPoints,
}
} }
// Program returns the loaded program. // Program returns the loaded program.
@ -85,44 +128,3 @@ func (c *Context) atBreakPoint() bool {
func (c *Context) String() string { func (c *Context) String() string {
return "execution context" return "execution context"
} }
func (c *Context) readUint32() uint32 {
start, end := c.IP(), c.IP()+4
if end > len(c.prog) {
panic("failed to read uint32 parameter")
}
val := binary.LittleEndian.Uint32(c.prog[start:end])
c.ip += 4
return val
}
func (c *Context) readUint16() uint16 {
start, end := c.IP(), c.IP()+2
if end > len(c.prog) {
panic("failed to read uint16 parameter")
}
val := binary.LittleEndian.Uint16(c.prog[start:end])
c.ip += 2
return val
}
func (c *Context) readByte() byte {
return c.readBytes(1)[0]
}
func (c *Context) readBytes(n int) []byte {
start, end := c.IP(), c.IP()+n
if end > len(c.prog) {
return nil
}
out := make([]byte, n)
copy(out, c.prog[start:end])
c.ip += n
return out
}
func (c *Context) readVarBytes() []byte {
n := c.readByte()
return c.readBytes(int(n))
}

View file

@ -10,13 +10,13 @@ type InteropFunc func(vm *VM) error
// runtimeLog will handle the syscall "Neo.Runtime.Log" for printing and logging stuff. // runtimeLog will handle the syscall "Neo.Runtime.Log" for printing and logging stuff.
func runtimeLog(vm *VM) error { func runtimeLog(vm *VM) error {
item := vm.Estack().Pop() item := vm.Estack().Pop()
fmt.Printf("NEO-GO-VM (log) > %s\n", item.value.Value()) fmt.Printf("NEO-GO-VM (log) > %s\n", item.Value())
return nil return nil
} }
// runtimeNotify will handle the syscall "Neo.Runtime.Notify" for printing and logging stuff. // runtimeNotify will handle the syscall "Neo.Runtime.Notify" for printing and logging stuff.
func runtimeNotify(vm *VM) error { func runtimeNotify(vm *VM) error {
item := vm.Estack().Pop() item := vm.Estack().Pop()
fmt.Printf("NEO-GO-VM (notify) > %s\n", item.value.Value()) fmt.Printf("NEO-GO-VM (notify) > %s\n", item.Value())
return nil return nil
} }

View file

@ -58,6 +58,11 @@ func (e *Element) Prev() *Element {
return nil return nil
} }
// Value returns value of the StackItem contained in the element.
func (e *Element) Value() interface{} {
return e.value.Value()
}
// BigInt attempts to get the underlying value of the element as a big integer. // BigInt attempts to get the underlying value of the element as a big integer.
// Will panic if the assertion failed which will be caught by the VM. // Will panic if the assertion failed which will be caught by the VM.
func (e *Element) BigInt() *big.Int { func (e *Element) BigInt() *big.Int {
@ -75,28 +80,40 @@ func (e *Element) BigInt() *big.Int {
} }
} }
// Bool attempts to get the underlying value of the element as a boolean. // TryBool attempts to get the underlying value of the element as a boolean.
// Will panic if the assertion failed which will be caught by the VM. // Returns error if can't convert value to boolean type.
func (e *Element) Bool() bool { func (e *Element) TryBool() (bool, error) {
switch t := e.value.(type) { switch t := e.value.(type) {
case *BigIntegerItem: case *BigIntegerItem:
return t.value.Int64() != 0 return t.value.Int64() != 0, nil
case *BoolItem: case *BoolItem:
return t.value return t.value, nil
case *ArrayItem, *StructItem: case *ArrayItem, *StructItem:
return true return true, nil
case *ByteArrayItem: case *ByteArrayItem:
for _, b := range t.value { for _, b := range t.value {
if b != 0 { if b != 0 {
return true return true, nil
} }
} }
return false return false, nil
case *InteropItem:
return t.value != nil, nil
default: default:
panic("can't convert to bool: " + t.String()) return false, fmt.Errorf("can't convert to bool: " + t.String())
} }
} }
// Bool attempts to get the underlying value of the element as a boolean.
// Will panic if the assertion failed which will be caught by the VM.
func (e *Element) Bool() bool {
val, err := e.TryBool()
if err != nil {
panic(err)
}
return val
}
// Bytes attempts to get the underlying value of the element as a byte array. // Bytes attempts to get the underlying value of the element as a byte array.
// Will panic if the assertion failed which will be caught by the VM. // Will panic if the assertion failed which will be caught by the VM.
func (e *Element) Bytes() []byte { func (e *Element) Bytes() []byte {

View file

@ -23,6 +23,10 @@ func makeStackItem(v interface{}) StackItem {
return &BigIntegerItem{ return &BigIntegerItem{
value: big.NewInt(val), value: big.NewInt(val),
} }
case uint32:
return &BigIntegerItem{
value: big.NewInt(int64(val)),
}
case []byte: case []byte:
return &ByteArrayItem{ return &ByteArrayItem{
value: val, value: val,
@ -248,3 +252,30 @@ func toMapKey(key StackItem) interface{} {
panic("wrong key type") panic("wrong key type")
} }
} }
// InteropItem represents interop data on the stack.
type InteropItem struct {
value interface{}
}
// NewInteropItem returns new InteropItem object.
func NewInteropItem(value interface{}) *InteropItem {
return &InteropItem{
value: value,
}
}
// Value implements StackItem interface.
func (i *InteropItem) Value() interface{} {
return i.value
}
// String implements stringer interface.
func (i *InteropItem) String() string {
return "InteropItem"
}
// MarshalJSON implements the json.Marshaler interface.
func (i *InteropItem) MarshalJSON() ([]byte, error) {
return json.Marshal(i.value)
}

View file

@ -45,9 +45,9 @@ func vmAndCompile(t *testing.T, src string) *vm.VM {
vm := vm.New(vm.ModeMute) vm := vm.New(vm.ModeMute)
storePlugin := newStoragePlugin() storePlugin := newStoragePlugin()
vm.RegisterInteropFunc("Neo.Storage.Get", storePlugin.Get) vm.RegisterInteropFunc("Neo.Storage.Get", storePlugin.Get, 1)
vm.RegisterInteropFunc("Neo.Storage.Put", storePlugin.Put) vm.RegisterInteropFunc("Neo.Storage.Put", storePlugin.Put, 1)
vm.RegisterInteropFunc("Neo.Storage.GetContext", storePlugin.GetContext) vm.RegisterInteropFunc("Neo.Storage.GetContext", storePlugin.GetContext, 1)
b, err := compiler.Compile(strings.NewReader(src), &compiler.Options{}) b, err := compiler.Compile(strings.NewReader(src), &compiler.Options{})
if err != nil { if err != nil {

View file

@ -2,6 +2,7 @@ package vm
import ( import (
"crypto/sha1" "crypto/sha1"
"encoding/binary"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"log" "log"
@ -9,6 +10,7 @@ import (
"os" "os"
"reflect" "reflect"
"text/tabwriter" "text/tabwriter"
"unicode/utf8"
"github.com/CityOfZion/neo-go/pkg/crypto/hash" "github.com/CityOfZion/neo-go/pkg/crypto/hash"
"github.com/CityOfZion/neo-go/pkg/crypto/keys" "github.com/CityOfZion/neo-go/pkg/crypto/keys"
@ -24,8 +26,10 @@ var (
) )
const ( const (
maxSHLArg = 256 // MaxArraySize is the maximum array size allowed in the VM.
minSHLArg = -256 MaxArraySize = 1024
maxSHLArg = 256
minSHLArg = -256
) )
// VM represents the virtual machine. // VM represents the virtual machine.
@ -33,10 +37,10 @@ type VM struct {
state State state State
// registered interop hooks. // registered interop hooks.
interop map[string]InteropFunc interop map[string]InteropFuncPrice
// scripts loaded in memory. // callback to get scripts.
scripts map[util.Uint160][]byte getScript func(util.Uint160) []byte
istack *Stack // invocation stack. istack *Stack // invocation stack.
estack *Stack // execution stack. estack *Stack // execution stack.
@ -48,30 +52,45 @@ type VM struct {
checkhash []byte checkhash []byte
} }
// InteropFuncPrice represents an interop function with a price.
type InteropFuncPrice struct {
Func InteropFunc
Price int
}
// New returns a new VM object ready to load .avm bytecode scripts. // New returns a new VM object ready to load .avm bytecode scripts.
func New(mode Mode) *VM { func New(mode Mode) *VM {
vm := &VM{ vm := &VM{
interop: make(map[string]InteropFunc), interop: make(map[string]InteropFuncPrice),
scripts: make(map[util.Uint160][]byte), getScript: nil,
state: haltState, state: haltState,
istack: NewStack("invocation"), istack: NewStack("invocation"),
estack: NewStack("evaluation"), estack: NewStack("evaluation"),
astack: NewStack("alt"), astack: NewStack("alt"),
} }
if mode == ModeMute { if mode == ModeMute {
vm.mute = true vm.mute = true
} }
// Register native interop hooks. // Register native interop hooks.
vm.RegisterInteropFunc("Neo.Runtime.Log", runtimeLog) vm.RegisterInteropFunc("Neo.Runtime.Log", runtimeLog, 1)
vm.RegisterInteropFunc("Neo.Runtime.Notify", runtimeNotify) vm.RegisterInteropFunc("Neo.Runtime.Notify", runtimeNotify, 1)
return vm return vm
} }
// RegisterInteropFunc will register the given InteropFunc to the VM. // RegisterInteropFunc will register the given InteropFunc to the VM.
func (v *VM) RegisterInteropFunc(name string, f InteropFunc) { func (v *VM) RegisterInteropFunc(name string, f InteropFunc, price int) {
v.interop[name] = f v.interop[name] = InteropFuncPrice{f, price}
}
// RegisterInteropFuncs will register all interop functions passed in a map in
// the VM. Effectively it's a batched version of RegisterInteropFunc.
func (v *VM) RegisterInteropFuncs(interops map[string]InteropFuncPrice) {
// We allow reregistration here.
for name, funPrice := range interops {
v.interop[name] = funPrice
}
} }
// Estack will return the evaluation stack so interop hooks can utilize this. // Estack will return the evaluation stack so interop hooks can utilize this.
@ -101,19 +120,45 @@ func (v *VM) LoadArgs(method []byte, args []StackItem) {
// PrintOps will print the opcodes of the current loaded program to stdout. // PrintOps will print the opcodes of the current loaded program to stdout.
func (v *VM) PrintOps() { func (v *VM) PrintOps() {
prog := v.Context().Program()
w := tabwriter.NewWriter(os.Stdout, 0, 0, 4, ' ', 0) w := tabwriter.NewWriter(os.Stdout, 0, 0, 4, ' ', 0)
fmt.Fprintln(w, "INDEX\tOPCODE\tDESC\t") fmt.Fprintln(w, "INDEX\tOPCODE\tPARAMETER\t")
cursor := "" realctx := v.Context()
ip, _ := v.Context().CurrInstr() ctx := realctx.Copy()
for i := 0; i < len(prog); i++ { ctx.ip = 0
if i == ip { ctx.nextip = 0
for {
cursor := ""
instr, parameter, err := ctx.Next()
if ctx.ip == realctx.ip {
cursor = "<<" cursor = "<<"
} else {
cursor = ""
} }
fmt.Fprintf(w, "%d\t0x%2x\t%s\t%s\n", i, prog[i], Instruction(prog[i]).String(), cursor) if err != nil {
fmt.Fprintf(w, "%d\t%s\tERROR: %s\t%s\n", ctx.ip, instr, err, cursor)
break
}
var desc = ""
if parameter != nil {
switch instr {
case JMP, JMPIF, JMPIFNOT, CALL:
offset := int16(binary.LittleEndian.Uint16(parameter))
desc = fmt.Sprintf("%d (%d/%x)", ctx.ip+int(offset), offset, parameter)
case SYSCALL:
desc = fmt.Sprintf("%q", parameter)
case APPCALL, TAILCALL:
desc = fmt.Sprintf("%x", parameter)
default:
if utf8.Valid(parameter) {
desc = fmt.Sprintf("%x (%q)", parameter, parameter)
} else {
desc = fmt.Sprintf("%x", parameter)
}
}
}
fmt.Fprintf(w, "%d\t%s\t%s\t%s\n", ctx.ip, instr, desc, cursor)
if ctx.nextip >= len(ctx.prog) {
break
}
} }
w.Flush() w.Flush()
} }
@ -164,13 +209,17 @@ func (v *VM) Context() *Context {
if v.istack.Len() == 0 { if v.istack.Len() == 0 {
return nil return nil
} }
return v.istack.Peek(0).value.Value().(*Context) return v.istack.Peek(0).Value().(*Context)
} }
// PopResult is used to pop the first item of the evaluation stack. This allows // PopResult is used to pop the first item of the evaluation stack. This allows
// us to test compiler and vm in a bi-directional way. // us to test compiler and vm in a bi-directional way.
func (v *VM) PopResult() interface{} { func (v *VM) PopResult() interface{} {
return v.estack.Pop().value.Value() e := v.estack.Pop()
if e != nil {
return e.Value()
}
return nil
} }
// Stack returns json formatted representation of the given stack. // Stack returns json formatted representation of the given stack.
@ -203,7 +252,15 @@ func (v *VM) Run() {
v.state = noneState v.state = noneState
for { for {
// check for breakpoint before executing the next instruction
ctx := v.Context()
if ctx != nil && ctx.atBreakPoint() {
v.state |= breakState
}
switch { switch {
case v.state.HasFlag(faultState):
fmt.Println("FAULT")
return
case v.state.HasFlag(haltState): case v.state.HasFlag(haltState):
if !v.mute { if !v.mute {
fmt.Println(v.Stack("estack")) fmt.Println(v.Stack("estack"))
@ -214,9 +271,6 @@ func (v *VM) Run() {
i, op := ctx.CurrInstr() i, op := ctx.CurrInstr()
fmt.Printf("at breakpoint %d (%s)\n", i, op.String()) fmt.Printf("at breakpoint %d (%s)\n", i, op.String())
return return
case v.state.HasFlag(faultState):
fmt.Println("FAULT")
return
case v.state == noneState: case v.state == noneState:
v.Step() v.Step()
} }
@ -226,14 +280,13 @@ func (v *VM) Run() {
// Step 1 instruction in the program. // Step 1 instruction in the program.
func (v *VM) Step() { func (v *VM) Step() {
ctx := v.Context() ctx := v.Context()
op := ctx.Next() op, param, err := ctx.Next()
v.execute(ctx, op) if err != nil {
log.Printf("error encountered at instruction %d (%s)", ctx.ip, op)
// re-peek the context as it could been changed during execution. log.Println(err)
cctx := v.Context() v.state = faultState
if cctx != nil && cctx.atBreakPoint() {
v.state = breakState
} }
v.execute(ctx, op, param)
} }
// HasFailed returns whether VM is in the failed state now. Usually used to // HasFailed returns whether VM is in the failed state now. Usually used to
@ -248,8 +301,13 @@ func (v *VM) SetCheckedHash(h []byte) {
copy(v.checkhash, h) copy(v.checkhash, h)
} }
// SetScriptGetter sets the script getter for CALL instructions.
func (v *VM) SetScriptGetter(gs func(util.Uint160) []byte) {
v.getScript = gs
}
// execute performs an instruction cycle in the VM. Acting on the instruction (opcode). // execute performs an instruction cycle in the VM. Acting on the instruction (opcode).
func (v *VM) execute(ctx *Context, op Instruction) { func (v *VM) execute(ctx *Context, op Instruction, parameter []byte) {
// Instead of polluting the whole VM logic with error handling, we will recover // Instead of polluting the whole VM logic with error handling, we will recover
// each panic at a central point, putting the VM in a fault state. // each panic at a central point, putting the VM in a fault state.
defer func() { defer func() {
@ -261,11 +319,7 @@ func (v *VM) execute(ctx *Context, op Instruction) {
}() }()
if op >= PUSHBYTES1 && op <= PUSHBYTES75 { if op >= PUSHBYTES1 && op <= PUSHBYTES75 {
b := ctx.readBytes(int(op)) v.estack.PushVal(parameter)
if b == nil {
panic("failed to read instruction parameter")
}
v.estack.PushVal(b)
return return
} }
@ -279,29 +333,8 @@ func (v *VM) execute(ctx *Context, op Instruction) {
case PUSH0: case PUSH0:
v.estack.PushVal([]byte{}) v.estack.PushVal([]byte{})
case PUSHDATA1: case PUSHDATA1, PUSHDATA2, PUSHDATA4:
n := ctx.readByte() v.estack.PushVal(parameter)
b := ctx.readBytes(int(n))
if b == nil {
panic("failed to read instruction parameter")
}
v.estack.PushVal(b)
case PUSHDATA2:
n := ctx.readUint16()
b := ctx.readBytes(int(n))
if b == nil {
panic("failed to read instruction parameter")
}
v.estack.PushVal(b)
case PUSHDATA4:
n := ctx.readUint32()
b := ctx.readBytes(int(n))
if b == nil {
panic("failed to read instruction parameter")
}
v.estack.PushVal(b)
// Stack operations. // Stack operations.
case TOALTSTACK: case TOALTSTACK:
@ -801,7 +834,7 @@ func (v *VM) execute(ctx *Context, op Instruction) {
elem := v.estack.Pop() elem := v.estack.Pop()
// Cause there is no native (byte) item type here, hence we need to check // Cause there is no native (byte) item type here, hence we need to check
// the type of the item for array size operations. // the type of the item for array size operations.
switch t := elem.value.Value().(type) { switch t := elem.Value().(type) {
case []StackItem: case []StackItem:
v.estack.PushVal(len(t)) v.estack.PushVal(len(t))
case map[interface{}]StackItem: case map[interface{}]StackItem:
@ -817,8 +850,8 @@ func (v *VM) execute(ctx *Context, op Instruction) {
case JMP, JMPIF, JMPIFNOT: case JMP, JMPIF, JMPIFNOT:
var ( var (
rOffset = int16(ctx.readUint16()) rOffset = int16(binary.LittleEndian.Uint16(parameter))
offset = ctx.ip + int(rOffset) - 3 // sizeOf(int16 + uint8) offset = ctx.ip + int(rOffset)
) )
if offset < 0 || offset > len(ctx.prog) { if offset < 0 || offset > len(ctx.prog) {
panic(fmt.Sprintf("JMP: invalid offset %d ip at %d", offset, ctx.ip)) panic(fmt.Sprintf("JMP: invalid offset %d ip at %d", offset, ctx.ip))
@ -831,36 +864,34 @@ func (v *VM) execute(ctx *Context, op Instruction) {
} }
} }
if cond { if cond {
ctx.ip = offset ctx.nextip = offset
} }
case CALL: case CALL:
v.istack.PushVal(ctx.Copy()) v.istack.PushVal(ctx.Copy())
ctx.ip += 2 v.execute(v.Context(), JMP, parameter)
v.execute(v.Context(), JMP)
case SYSCALL: case SYSCALL:
api := ctx.readVarBytes() ifunc, ok := v.interop[string(parameter)]
ifunc, ok := v.interop[string(api)]
if !ok { if !ok {
panic(fmt.Sprintf("interop hook (%s) not registered", api)) panic(fmt.Sprintf("interop hook (%q) not registered", parameter))
} }
if err := ifunc(v); err != nil { if err := ifunc.Func(v); err != nil {
panic(fmt.Sprintf("failed to invoke syscall: %s", err)) panic(fmt.Sprintf("failed to invoke syscall: %s", err))
} }
case APPCALL, TAILCALL: case APPCALL, TAILCALL:
if len(v.scripts) == 0 { if v.getScript == nil {
panic("script table is empty") panic("no getScript callback is set up")
} }
hash, err := util.Uint160DecodeBytes(ctx.readBytes(20)) hash, err := util.Uint160DecodeBytes(parameter)
if err != nil { if err != nil {
panic(err) panic(err)
} }
script, ok := v.scripts[hash] script := v.getScript(hash)
if !ok { if script == nil {
panic("could not find script") panic("could not find script")
} }

View file

@ -18,7 +18,7 @@ func TestInteropHook(t *testing.T) {
v.RegisterInteropFunc("foo", func(evm *VM) error { v.RegisterInteropFunc("foo", func(evm *VM) error {
evm.Estack().PushVal(1) evm.Estack().PushVal(1)
return nil return nil
}) }, 1)
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
EmitSyscall(buf, "foo") EmitSyscall(buf, "foo")
@ -33,7 +33,7 @@ func TestInteropHook(t *testing.T) {
func TestRegisterInterop(t *testing.T) { func TestRegisterInterop(t *testing.T) {
v := New(ModeMute) v := New(ModeMute)
currRegistered := len(v.interop) currRegistered := len(v.interop)
v.RegisterInteropFunc("foo", func(evm *VM) error { return nil }) v.RegisterInteropFunc("foo", func(evm *VM) error { return nil }, 1)
assert.Equal(t, currRegistered+1, len(v.interop)) assert.Equal(t, currRegistered+1, len(v.interop))
_, ok := v.interop["foo"] _, ok := v.interop["foo"]
assert.Equal(t, true, ok) assert.Equal(t, true, ok)
@ -54,7 +54,7 @@ func TestPushBytes1to75(t *testing.T) {
assert.IsType(t, elem.Bytes(), b) assert.IsType(t, elem.Bytes(), b)
assert.Equal(t, 0, vm.estack.Len()) assert.Equal(t, 0, vm.estack.Len())
vm.execute(nil, RET) vm.execute(nil, RET, nil)
assert.Equal(t, 0, vm.astack.Len()) assert.Equal(t, 0, vm.astack.Len())
assert.Equal(t, 0, vm.istack.Len()) assert.Equal(t, 0, vm.istack.Len())
@ -1000,7 +1000,12 @@ func TestAppCall(t *testing.T) {
prog = append(prog, byte(RET)) prog = append(prog, byte(RET))
vm := load(prog) vm := load(prog)
vm.scripts[hash] = makeProgram(DEPTH) vm.SetScriptGetter(func(in util.Uint160) []byte {
if in.Equals(hash) {
return makeProgram(DEPTH)
}
return nil
})
vm.estack.PushVal(2) vm.estack.PushVal(2)
vm.Run() vm.Run()