*: move syscall handling out of VM

Remove interop-related structures from the `vm` package.

Signed-off-by: Evgenii Stratonikov <evgeniy@nspcc.ru>
This commit is contained in:
Evgenii Stratonikov 2020-07-28 16:38:00 +03:00
parent f24e707ea1
commit 51ae12e4fd
17 changed files with 195 additions and 257 deletions

View file

@ -1,6 +1,7 @@
package compiler_test package compiler_test
import ( import (
"errors"
"fmt" "fmt"
"math/big" "math/big"
"testing" "testing"
@ -20,7 +21,7 @@ func TestPanic(t *testing.T) {
var logs []string var logs []string
src := getPanicSource(true, `"execution fault"`) src := getPanicSource(true, `"execution fault"`)
v := vmAndCompile(t, src) v := vmAndCompile(t, src)
v.RegisterInteropGetter(logGetter(&logs)) v.SyscallHandler = getLogHandler(&logs)
require.Error(t, v.Run()) require.Error(t, v.Run())
require.True(t, v.HasFailed()) require.True(t, v.HasFailed())
@ -32,7 +33,7 @@ func TestPanic(t *testing.T) {
var logs []string var logs []string
src := getPanicSource(true, `nil`) src := getPanicSource(true, `nil`)
v := vmAndCompile(t, src) v := vmAndCompile(t, src)
v.RegisterInteropGetter(logGetter(&logs)) v.SyscallHandler = getLogHandler(&logs)
require.Error(t, v.Run()) require.Error(t, v.Run())
require.True(t, v.HasFailed()) require.True(t, v.HasFailed())
@ -54,19 +55,15 @@ func getPanicSource(need bool, message string) string {
`, need, message) `, need, message)
} }
func logGetter(logs *[]string) vm.InteropGetterFunc { func getLogHandler(logs *[]string) vm.SyscallHandler {
logID := emit.InteropNameToID([]byte("System.Runtime.Log")) logID := emit.InteropNameToID([]byte("System.Runtime.Log"))
return func(id uint32) *vm.InteropFuncPrice { return func(v *vm.VM, id uint32) error {
if id != logID { if id != logID {
return nil return errors.New("syscall not found")
} }
return &vm.InteropFuncPrice{ msg := string(v.Estack().Pop().Bytes())
Func: func(v *vm.VM) error { *logs = append(*logs, msg)
msg := string(v.Estack().Pop().Bytes()) return nil
*logs = append(*logs, msg)
return nil
},
}
} }
} }

View file

@ -23,7 +23,8 @@ func TestSHA256(t *testing.T) {
` `
v := vmAndCompile(t, src) v := vmAndCompile(t, src)
ic := &interop.Context{Trigger: trigger.Verification} ic := &interop.Context{Trigger: trigger.Verification}
v.RegisterInteropGetter(crypto.GetInterop(ic)) crypto.Register(ic)
v.SyscallHandler = ic.SyscallHandler
require.NoError(t, v.Run()) require.NoError(t, v.Run())
require.True(t, v.Estack().Len() >= 1) require.True(t, v.Estack().Len() >= 1)

View file

@ -1,6 +1,7 @@
package compiler_test package compiler_test
import ( import (
"errors"
"fmt" "fmt"
"strings" "strings"
"testing" "testing"
@ -68,7 +69,8 @@ func vmAndCompileInterop(t *testing.T, src string) (*vm.VM, *storagePlugin) {
vm := vm.New() vm := vm.New()
storePlugin := newStoragePlugin() storePlugin := newStoragePlugin()
vm.RegisterInteropGetter(storePlugin.getInterop) vm.GasLimit = -1
vm.SyscallHandler = storePlugin.syscallHandler
b, di, err := compiler.CompileWithDebugInfo(strings.NewReader(src)) b, di, err := compiler.CompileWithDebugInfo(strings.NewReader(src))
require.NoError(t, err) require.NoError(t, err)
@ -97,14 +99,14 @@ func invokeMethod(t *testing.T, method string, script []byte, v *vm.VM, di *comp
type storagePlugin struct { type storagePlugin struct {
mem map[string][]byte mem map[string][]byte
interops map[uint32]vm.InteropFunc interops map[uint32]func(v *vm.VM) error
events []state.NotificationEvent events []state.NotificationEvent
} }
func newStoragePlugin() *storagePlugin { func newStoragePlugin() *storagePlugin {
s := &storagePlugin{ s := &storagePlugin{
mem: make(map[string][]byte), mem: make(map[string][]byte),
interops: make(map[uint32]vm.InteropFunc), interops: make(map[uint32]func(v *vm.VM) error),
} }
s.interops[emit.InteropNameToID([]byte("System.Storage.Get"))] = s.Get s.interops[emit.InteropNameToID([]byte("System.Storage.Get"))] = s.Get
s.interops[emit.InteropNameToID([]byte("System.Storage.Put"))] = s.Put s.interops[emit.InteropNameToID([]byte("System.Storage.Put"))] = s.Put
@ -114,12 +116,15 @@ func newStoragePlugin() *storagePlugin {
} }
func (s *storagePlugin) getInterop(id uint32) *vm.InteropFuncPrice { func (s *storagePlugin) syscallHandler(v *vm.VM, id uint32) error {
f := s.interops[id] f := s.interops[id]
if f != nil { if f != nil {
return &vm.InteropFuncPrice{Func: f, Price: 1} if !v.AddGas(1) {
return errors.New("insufficient amount of gas")
}
return f(v)
} }
return nil return errors.New("syscall not found")
} }
func (s *storagePlugin) Notify(v *vm.VM) error { func (s *storagePlugin) Notify(v *vm.VM) error {

View file

@ -11,8 +11,8 @@ import (
"github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/core/transaction"
"github.com/nspcc-dev/neo-go/pkg/crypto/hash" "github.com/nspcc-dev/neo-go/pkg/crypto/hash"
"github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/io"
"github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger"
"github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/util"
"github.com/nspcc-dev/neo-go/pkg/vm"
"github.com/nspcc-dev/neo-go/pkg/vm/emit" "github.com/nspcc-dev/neo-go/pkg/vm/emit"
"github.com/pkg/errors" "github.com/pkg/errors"
) )
@ -222,9 +222,10 @@ func (p *Payload) Verify(scriptHash util.Uint160) bool {
return false return false
} }
v := vm.New() ic := &interop.Context{Trigger: trigger.Verification, Container: p}
crypto.Register(ic)
v := ic.SpawnVM()
v.GasLimit = payloadGasLimit v.GasLimit = payloadGasLimit
v.RegisterInteropGetter(crypto.GetInterop(&interop.Context{Container: p}))
v.LoadScript(verification) v.LoadScript(verification)
v.LoadScript(p.Witness.InvocationScript) v.LoadScript(p.Witness.InvocationScript)

View file

@ -566,8 +566,7 @@ func (bc *Blockchain) storeBlock(block *block.Block) error {
if block.Index > 0 { if block.Index > 0 {
systemInterop := bc.newInteropContext(trigger.System, cache, block, nil) systemInterop := bc.newInteropContext(trigger.System, cache, block, nil)
v := SpawnVM(systemInterop) v := systemInterop.SpawnVM()
v.GasLimit = -1
v.LoadScriptWithFlags(bc.contracts.GetPersistScript(), smartcontract.AllowModifyStates|smartcontract.AllowCall) v.LoadScriptWithFlags(bc.contracts.GetPersistScript(), smartcontract.AllowModifyStates|smartcontract.AllowCall)
v.SetPriceGetter(getPrice) v.SetPriceGetter(getPrice)
if err := v.Run(); err != nil { if err := v.Run(); err != nil {
@ -599,7 +598,7 @@ func (bc *Blockchain) storeBlock(block *block.Block) error {
} }
systemInterop := bc.newInteropContext(trigger.Application, cache, block, tx) systemInterop := bc.newInteropContext(trigger.Application, cache, block, tx)
v := SpawnVM(systemInterop) v := systemInterop.SpawnVM()
v.LoadScriptWithFlags(tx.Script, smartcontract.All) v.LoadScriptWithFlags(tx.Script, smartcontract.All)
v.SetPriceGetter(getPrice) v.SetPriceGetter(getPrice)
v.GasLimit = tx.SystemFee v.GasLimit = tx.SystemFee
@ -1277,7 +1276,7 @@ func (bc *Blockchain) GetScriptHashesForVerifying(t *transaction.Transaction) ([
// GetTestVM returns a VM and a Store setup for a test run of some sort of code. // GetTestVM returns a VM and a Store setup for a test run of some sort of code.
func (bc *Blockchain) GetTestVM(tx *transaction.Transaction) *vm.VM { func (bc *Blockchain) GetTestVM(tx *transaction.Transaction) *vm.VM {
systemInterop := bc.newInteropContext(trigger.Application, bc.dao, nil, tx) systemInterop := bc.newInteropContext(trigger.Application, bc.dao, nil, tx)
vm := SpawnVM(systemInterop) vm := systemInterop.SpawnVM()
vm.SetPriceGetter(getPrice) vm.SetPriceGetter(getPrice)
return vm return vm
} }
@ -1310,7 +1309,7 @@ func (bc *Blockchain) verifyHashAgainstScript(hash util.Uint160, witness *transa
gas = gasPolicy gas = gasPolicy
} }
vm := SpawnVM(interopCtx) vm := interopCtx.SpawnVM()
vm.SetPriceGetter(getPrice) vm.SetPriceGetter(getPrice)
vm.GasLimit = gas vm.GasLimit = gas
vm.LoadScriptWithFlags(verification, smartcontract.ReadOnly) vm.LoadScriptWithFlags(verification, smartcontract.ReadOnly)
@ -1413,6 +1412,7 @@ func (bc *Blockchain) secondsPerBlock() int {
func (bc *Blockchain) newInteropContext(trigger trigger.Type, d dao.DAO, block *block.Block, tx *transaction.Transaction) *interop.Context { func (bc *Blockchain) newInteropContext(trigger trigger.Type, d dao.DAO, block *block.Block, tx *transaction.Transaction) *interop.Context {
ic := interop.NewContext(trigger, bc, d, bc.contracts.Contracts, block, tx, bc.log) ic := interop.NewContext(trigger, bc, d, bc.contracts.Contracts, block, tx, bc.log)
ic.Functions = [][]interop.Function{systemInterops, neoInterops}
switch { switch {
case tx != nil: case tx != nil:
ic.Container = tx ic.Container = tx

View file

@ -9,14 +9,6 @@ import (
const StoragePrice = 100000 const StoragePrice = 100000
// getPrice returns a price for executing op with the provided parameter. // getPrice returns a price for executing op with the provided parameter.
// Some SYSCALLs have variable price depending on their arguments.
func getPrice(v *vm.VM, op opcode.Opcode, parameter []byte) int64 { func getPrice(v *vm.VM, op opcode.Opcode, parameter []byte) int64 {
if op == opcode.SYSCALL {
interopID := vm.GetInteropID(parameter)
ifunc := v.GetInteropByID(interopID)
if ifunc != nil && ifunc.Price > 0 {
return ifunc.Price
}
}
return opcodePrice(op) return opcodePrice(op)
} }

View file

@ -1,6 +1,10 @@
package interop package interop
import ( import (
"errors"
"fmt"
"sort"
"github.com/nspcc-dev/neo-go/pkg/core/block" "github.com/nspcc-dev/neo-go/pkg/core/block"
"github.com/nspcc-dev/neo-go/pkg/core/blockchainer" "github.com/nspcc-dev/neo-go/pkg/core/blockchainer"
"github.com/nspcc-dev/neo-go/pkg/core/dao" "github.com/nspcc-dev/neo-go/pkg/core/dao"
@ -32,6 +36,7 @@ type Context struct {
Log *zap.Logger Log *zap.Logger
Invocations map[util.Uint160]int Invocations map[util.Uint160]int
ScriptGetter vm.ScriptHashGetter ScriptGetter vm.ScriptHashGetter
Functions [][]Function
} }
// NewContext returns new interop context. // NewContext returns new interop context.
@ -48,6 +53,8 @@ func NewContext(trigger trigger.Type, bc blockchainer.Blockchainer, d dao.DAO, n
Notifications: nes, Notifications: nes,
Log: log, Log: log,
Invocations: make(map[util.Uint160]int), Invocations: make(map[util.Uint160]int),
// Functions is a slice of slices of interops sorted by ID.
Functions: [][]Function{},
} }
} }
@ -124,3 +131,46 @@ func (c *ContractMD) AddEvent(name string, ps ...manifest.Parameter) {
Parameters: ps, Parameters: ps,
}) })
} }
// Sort sorts interop functions by id.
func Sort(fs []Function) {
sort.Slice(fs, func(i, j int) bool { return fs[i].ID < fs[j].ID })
}
// GetFunction returns metadata for interop with the specified id.
func (ic *Context) GetFunction(id uint32) *Function {
for _, slice := range ic.Functions {
n := sort.Search(len(slice), func(i int) bool {
return slice[i].ID >= id
})
if n < len(slice) && slice[n].ID == id {
return &slice[n]
}
}
return nil
}
// SyscallHandler handles syscall with id.
func (ic *Context) SyscallHandler(v *vm.VM, id uint32) error {
f := ic.GetFunction(id)
if f == nil {
return errors.New("syscall not found")
}
cf := v.Context().GetCallFlags()
if !cf.Has(f.RequiredFlags) {
return fmt.Errorf("missing call flags: %05b vs %05b", cf, f.RequiredFlags)
}
if !v.AddGas(f.Price) {
return errors.New("insufficient amount of gas")
}
return f.Func(ic, v)
}
// SpawnVM spawns new VM with the specified gas limit.
func (ic *Context) SpawnVM() *vm.VM {
v := vm.NewWithTrigger(ic.Trigger)
v.GasLimit = -1
v.SyscallHandler = ic.SyscallHandler
ic.ScriptGetter = v
return v
}

View file

@ -59,10 +59,9 @@ func initCHECKMULTISIGVM(t *testing.T, n int, ik, is []int) *vm.VM {
buf[0] = byte(opcode.SYSCALL) buf[0] = byte(opcode.SYSCALL)
binary.LittleEndian.PutUint32(buf[1:], ecdsaSecp256r1CheckMultisigID) binary.LittleEndian.PutUint32(buf[1:], ecdsaSecp256r1CheckMultisigID)
v := vm.New()
v.GasLimit = -1
ic := &interop.Context{Trigger: trigger.Verification} ic := &interop.Context{Trigger: trigger.Verification}
v.RegisterInteropGetter(GetInterop(ic)) Register(ic)
v := ic.SpawnVM()
v.LoadScript(buf) v.LoadScript(buf)
msg := []byte("NEO - An Open Network For Smart Economy") msg := []byte("NEO - An Open Network For Smart Economy")

View file

@ -7,7 +7,6 @@ import (
"github.com/nspcc-dev/neo-go/pkg/core/interop" "github.com/nspcc-dev/neo-go/pkg/core/interop"
"github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/io"
"github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger" "github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger"
"github.com/nspcc-dev/neo-go/pkg/vm"
"github.com/nspcc-dev/neo-go/pkg/vm/emit" "github.com/nspcc-dev/neo-go/pkg/vm/emit"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -20,9 +19,9 @@ func TestSHA256(t *testing.T) {
emit.Bytes(buf.BinWriter, []byte{1, 0}) emit.Bytes(buf.BinWriter, []byte{1, 0})
emit.Syscall(buf.BinWriter, "Neo.Crypto.SHA256") emit.Syscall(buf.BinWriter, "Neo.Crypto.SHA256")
prog := buf.Bytes() prog := buf.Bytes()
v := vm.New()
ic := &interop.Context{Trigger: trigger.Verification} ic := &interop.Context{Trigger: trigger.Verification}
v.RegisterInteropGetter(GetInterop(ic)) Register(ic)
v := ic.SpawnVM()
v.Load(prog) v.Load(prog)
require.NoError(t, v.Run()) require.NoError(t, v.Run())
assert.Equal(t, 1, v.Estack().Len()) assert.Equal(t, 1, v.Estack().Len())
@ -36,9 +35,9 @@ func TestRIPEMD160(t *testing.T) {
emit.Bytes(buf.BinWriter, []byte{1, 0}) emit.Bytes(buf.BinWriter, []byte{1, 0})
emit.Syscall(buf.BinWriter, "Neo.Crypto.RIPEMD160") emit.Syscall(buf.BinWriter, "Neo.Crypto.RIPEMD160")
prog := buf.Bytes() prog := buf.Bytes()
v := vm.New()
ic := &interop.Context{Trigger: trigger.Verification} ic := &interop.Context{Trigger: trigger.Verification}
v.RegisterInteropGetter(GetInterop(ic)) Register(ic)
v := ic.SpawnVM()
v.Load(prog) v.Load(prog)
require.NoError(t, v.Run()) require.NoError(t, v.Run())
assert.Equal(t, 1, v.Estack().Len()) assert.Equal(t, 1, v.Estack().Len())

View file

@ -2,7 +2,6 @@ package crypto
import ( import (
"github.com/nspcc-dev/neo-go/pkg/core/interop" "github.com/nspcc-dev/neo-go/pkg/core/interop"
"github.com/nspcc-dev/neo-go/pkg/vm"
"github.com/nspcc-dev/neo-go/pkg/vm/emit" "github.com/nspcc-dev/neo-go/pkg/vm/emit"
) )
@ -13,36 +12,18 @@ var (
ripemd160ID = emit.InteropNameToID([]byte("Neo.Crypto.RIPEMD160")) ripemd160ID = emit.InteropNameToID([]byte("Neo.Crypto.RIPEMD160"))
) )
// GetInterop returns interop getter for crypto-related stuff. var cryptoInterops = []interop.Function{
func GetInterop(ic *interop.Context) func(uint32) *vm.InteropFuncPrice { {ID: ecdsaSecp256r1VerifyID, Func: ECDSASecp256r1Verify},
return func(id uint32) *vm.InteropFuncPrice { {ID: ecdsaSecp256r1CheckMultisigID, Func: ECDSASecp256r1CheckMultisig},
switch id { {ID: sha256ID, Func: Sha256},
case ecdsaSecp256r1VerifyID: {ID: ripemd160ID, Func: RipeMD160},
return &vm.InteropFuncPrice{ }
Func: func(v *vm.VM) error {
return ECDSASecp256r1Verify(ic, v) func init() {
}, interop.Sort(cryptoInterops)
} }
case ecdsaSecp256r1CheckMultisigID:
return &vm.InteropFuncPrice{ // Register adds crypto interops to ic.
Func: func(v *vm.VM) error { func Register(ic *interop.Context) {
return ECDSASecp256r1CheckMultisig(ic, v) ic.Functions = append(ic.Functions, cryptoInterops)
},
}
case sha256ID:
return &vm.InteropFuncPrice{
Func: func(v *vm.VM) error {
return Sha256(ic, v)
},
}
case ripemd160ID:
return &vm.InteropFuncPrice{
Func: func(v *vm.VM) error {
return RipeMD160(ic, v)
},
}
default:
return nil
}
}
} }

View file

@ -4,7 +4,7 @@ import (
"encoding/binary" "encoding/binary"
"testing" "testing"
"github.com/nspcc-dev/neo-go/pkg/vm" "github.com/nspcc-dev/neo-go/pkg/core/interop"
"github.com/nspcc-dev/neo-go/pkg/vm/emit" "github.com/nspcc-dev/neo-go/pkg/vm/emit"
"github.com/nspcc-dev/neo-go/pkg/vm/opcode" "github.com/nspcc-dev/neo-go/pkg/vm/opcode"
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
@ -16,19 +16,13 @@ var (
deserializeID = emit.InteropNameToID([]byte("System.Json.Deserialize")) deserializeID = emit.InteropNameToID([]byte("System.Json.Deserialize"))
) )
func getInterop(id uint32) *vm.InteropFuncPrice { var jsonInterops = []interop.Function{
switch id { {ID: serializeID, Func: Serialize},
case serializeID: {ID: deserializeID, Func: Deserialize},
return &vm.InteropFuncPrice{ }
Func: func(v *vm.VM) error { return Serialize(nil, v) },
} func init() {
case deserializeID: interop.Sort(jsonInterops)
return &vm.InteropFuncPrice{
Func: func(v *vm.VM) error { return Deserialize(nil, v) },
}
default:
return nil
}
} }
func getTestFunc(id uint32, arg interface{}, result interface{}) func(t *testing.T) { func getTestFunc(id uint32, arg interface{}, result interface{}) func(t *testing.T) {
@ -37,8 +31,9 @@ func getTestFunc(id uint32, arg interface{}, result interface{}) func(t *testing
binary.LittleEndian.PutUint32(prog[1:], id) binary.LittleEndian.PutUint32(prog[1:], id)
return func(t *testing.T) { return func(t *testing.T) {
v := vm.New() ic := &interop.Context{}
v.RegisterInteropGetter(getInterop) ic.Functions = append(ic.Functions, jsonInterops)
v := ic.SpawnVM()
v.LoadScript(prog) v.LoadScript(prog)
v.Estack().PushVal(arg) v.Estack().PushVal(arg)
if result == nil { if result == nil {

View file

@ -8,8 +8,6 @@ package core
*/ */
import ( import (
"sort"
"github.com/nspcc-dev/neo-go/pkg/core/interop" "github.com/nspcc-dev/neo-go/pkg/core/interop"
"github.com/nspcc-dev/neo-go/pkg/core/interop/crypto" "github.com/nspcc-dev/neo-go/pkg/core/interop/crypto"
"github.com/nspcc-dev/neo-go/pkg/core/interop/enumerator" "github.com/nspcc-dev/neo-go/pkg/core/interop/enumerator"
@ -25,45 +23,11 @@ import (
// SpawnVM returns a VM with script getter and interop functions set // SpawnVM returns a VM with script getter and interop functions set
// up for current blockchain. // up for current blockchain.
func SpawnVM(ic *interop.Context) *vm.VM { func SpawnVM(ic *interop.Context) *vm.VM {
vm := vm.NewWithTrigger(ic.Trigger) vm := ic.SpawnVM()
vm.RegisterInteropGetter(getSystemInterop(ic)) ic.Functions = [][]interop.Function{systemInterops, neoInterops}
vm.RegisterInteropGetter(getNeoInterop(ic))
ic.ScriptGetter = vm
return vm return vm
} }
// getSystemInterop returns matching interop function from the System namespace
// for a given id in the current context.
func getSystemInterop(ic *interop.Context) vm.InteropGetterFunc {
return getInteropFromSlice(ic, systemInterops)
}
// getNeoInterop returns matching interop function from the Neo and AntShares
// namespaces for a given id in the current context.
func getNeoInterop(ic *interop.Context) vm.InteropGetterFunc {
return getInteropFromSlice(ic, neoInterops)
}
// getInteropFromSlice returns matching interop function from the given slice of
// interop functions in the current context.
func getInteropFromSlice(ic *interop.Context, slice []interop.Function) func(uint32) *vm.InteropFuncPrice {
return func(id uint32) *vm.InteropFuncPrice {
n := sort.Search(len(slice), func(i int) bool {
return slice[i].ID >= id
})
if n < len(slice) && slice[n].ID == id {
return &vm.InteropFuncPrice{
Func: func(v *vm.VM) error {
return slice[n].Func(ic, v)
},
Price: slice[n].Price,
RequiredFlags: slice[n].RequiredFlags,
}
}
return nil
}
}
// All lists are sorted, keep 'em this way, please. // All lists are sorted, keep 'em this way, please.
var systemInterops = []interop.Function{ var systemInterops = []interop.Function{
{Name: "System.Binary.Base64Decode", Func: runtimeDecode, Price: 100000}, {Name: "System.Binary.Base64Decode", Func: runtimeDecode, Price: 100000},
@ -136,9 +100,7 @@ func initIDinInteropsSlice(iops []interop.Function) {
for i := range iops { for i := range iops {
iops[i].ID = emit.InteropNameToID([]byte(iops[i].Name)) iops[i].ID = emit.InteropNameToID([]byte(iops[i].Name))
} }
sort.Slice(iops, func(i, j int) bool { interop.Sort(iops)
return iops[i].ID < iops[j].ID
})
} }
// init initializes IDs in the global interop slices. // init initializes IDs in the global interop slices.

View file

@ -110,9 +110,9 @@ func TestParameterContext_AddSignatureMultisig(t *testing.T) {
} }
func newTestVM(w *transaction.Witness, tx *transaction.Transaction) *vm.VM { func newTestVM(w *transaction.Witness, tx *transaction.Transaction) *vm.VM {
v := vm.New() ic := &interop.Context{Container: tx}
v.GasLimit = -1 crypto.Register(ic)
v.RegisterInteropGetter(crypto.GetInterop(&interop.Context{Container: tx})) v := ic.SpawnVM()
v.LoadScript(w.VerificationScript) v.LoadScript(w.VerificationScript)
v.LoadScript(w.InvocationScript) v.LoadScript(w.InvocationScript)
return v return v

View file

@ -10,63 +10,59 @@ import (
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
) )
// InteropFunc allows to hook into the VM. // interopIDFuncPrice adds an ID to the InteropFuncPrice.
type InteropFunc func(vm *VM) error type interopIDFuncPrice struct {
ID uint32
// InteropFuncPrice represents an interop function with a price. Func func(vm *VM) error
type InteropFuncPrice struct {
Func InteropFunc
Price int64 Price int64
RequiredFlags smartcontract.CallFlag RequiredFlags smartcontract.CallFlag
} }
// interopIDFuncPrice adds an ID to the InteropFuncPrice.
type interopIDFuncPrice struct {
ID uint32
InteropFuncPrice
}
// InteropGetterFunc is a function that returns an interop function-price
// structure by the given interop ID.
type InteropGetterFunc func(uint32) *InteropFuncPrice
var defaultVMInterops = []interopIDFuncPrice{ var defaultVMInterops = []interopIDFuncPrice{
{emit.InteropNameToID([]byte("System.Binary.Deserialize")), {ID: emit.InteropNameToID([]byte("System.Binary.Deserialize")),
InteropFuncPrice{Func: RuntimeDeserialize, Price: 500000}}, Func: RuntimeDeserialize, Price: 500000},
{emit.InteropNameToID([]byte("System.Binary.Serialize")), {ID: emit.InteropNameToID([]byte("System.Binary.Serialize")),
InteropFuncPrice{Func: RuntimeSerialize, Price: 100000}}, Func: RuntimeSerialize, Price: 100000},
{emit.InteropNameToID([]byte("System.Runtime.Log")), {ID: emit.InteropNameToID([]byte("System.Runtime.Log")),
InteropFuncPrice{Func: runtimeLog, Price: 1000000, RequiredFlags: smartcontract.AllowNotify}}, Func: runtimeLog, Price: 1000000, RequiredFlags: smartcontract.AllowNotify},
{emit.InteropNameToID([]byte("System.Runtime.Notify")), {ID: emit.InteropNameToID([]byte("System.Runtime.Notify")),
InteropFuncPrice{Func: runtimeNotify, Price: 1000000, RequiredFlags: smartcontract.AllowNotify}}, Func: runtimeNotify, Price: 1000000, RequiredFlags: smartcontract.AllowNotify},
{emit.InteropNameToID([]byte("System.Enumerator.Create")), {ID: emit.InteropNameToID([]byte("System.Enumerator.Create")),
InteropFuncPrice{Func: EnumeratorCreate, Price: 400}}, Func: EnumeratorCreate, Price: 400},
{emit.InteropNameToID([]byte("System.Enumerator.Next")), {ID: emit.InteropNameToID([]byte("System.Enumerator.Next")),
InteropFuncPrice{Func: EnumeratorNext, Price: 1000000}}, Func: EnumeratorNext, Price: 1000000},
{emit.InteropNameToID([]byte("System.Enumerator.Concat")), {ID: emit.InteropNameToID([]byte("System.Enumerator.Concat")),
InteropFuncPrice{Func: EnumeratorConcat, Price: 400}}, Func: EnumeratorConcat, Price: 400},
{emit.InteropNameToID([]byte("System.Enumerator.Value")), {ID: emit.InteropNameToID([]byte("System.Enumerator.Value")),
InteropFuncPrice{Func: EnumeratorValue, Price: 400}}, Func: EnumeratorValue, Price: 400},
{emit.InteropNameToID([]byte("System.Iterator.Create")), {ID: emit.InteropNameToID([]byte("System.Iterator.Create")),
InteropFuncPrice{Func: IteratorCreate, Price: 400}}, Func: IteratorCreate, Price: 400},
{emit.InteropNameToID([]byte("System.Iterator.Concat")), {ID: emit.InteropNameToID([]byte("System.Iterator.Concat")),
InteropFuncPrice{Func: IteratorConcat, Price: 400}}, Func: IteratorConcat, Price: 400},
{emit.InteropNameToID([]byte("System.Iterator.Key")), {ID: emit.InteropNameToID([]byte("System.Iterator.Key")),
InteropFuncPrice{Func: IteratorKey, Price: 400}}, Func: IteratorKey, Price: 400},
{emit.InteropNameToID([]byte("System.Iterator.Keys")), {ID: emit.InteropNameToID([]byte("System.Iterator.Keys")),
InteropFuncPrice{Func: IteratorKeys, Price: 400}}, Func: IteratorKeys, Price: 400},
{emit.InteropNameToID([]byte("System.Iterator.Values")), {ID: emit.InteropNameToID([]byte("System.Iterator.Values")),
InteropFuncPrice{Func: IteratorValues, Price: 400}}, Func: IteratorValues, Price: 400},
} }
func getDefaultVMInterop(id uint32) *InteropFuncPrice { func init() {
sort.Slice(defaultVMInterops, func(i, j int) bool { return defaultVMInterops[i].ID < defaultVMInterops[j].ID })
}
func defaultSyscallHandler(v *VM, id uint32) error {
n := sort.Search(len(defaultVMInterops), func(i int) bool { n := sort.Search(len(defaultVMInterops), func(i int) bool {
return defaultVMInterops[i].ID >= id return defaultVMInterops[i].ID >= id
}) })
if n < len(defaultVMInterops) && defaultVMInterops[n].ID == id { if n >= len(defaultVMInterops) || defaultVMInterops[n].ID != id {
return &defaultVMInterops[n].InteropFuncPrice return errors.New("syscall not found")
} }
return nil d := defaultVMInterops[n]
if !v.Context().callFlag.Has(d.RequiredFlags) {
return fmt.Errorf("missing call flags: %05b vs %05b", v.Context().callFlag, d.RequiredFlags)
}
return d.Func(v)
} }
// runtimeLog handles the syscall "System.Runtime.Log" for printing and logging stuff. // runtimeLog handles the syscall "System.Runtime.Log" for printing and logging stuff.

View file

@ -19,6 +19,7 @@ import (
"github.com/nspcc-dev/neo-go/pkg/smartcontract" "github.com/nspcc-dev/neo-go/pkg/smartcontract"
"github.com/nspcc-dev/neo-go/pkg/vm/opcode" "github.com/nspcc-dev/neo-go/pkg/vm/opcode"
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
"github.com/pkg/errors"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -110,30 +111,21 @@ func TestUT(t *testing.T) {
require.Equal(t, true, testsRan, "neo-vm tests should be available (check submodules)") require.Equal(t, true, testsRan, "neo-vm tests should be available (check submodules)")
} }
func getTestingInterop(id uint32) *InteropFuncPrice { func testSyscallHandler(v *VM, id uint32) error {
f := func(v *VM) error {
v.estack.PushVal(stackitem.NewInterop(new(int)))
return nil
}
switch id { switch id {
case 0x77777777: case 0x77777777:
return &InteropFuncPrice{Func: f} v.Estack().PushVal(stackitem.NewInterop(new(int)))
case 0x66666666: case 0x66666666:
return &InteropFuncPrice{ if !v.Context().callFlag.Has(smartcontract.ReadOnly) {
Func: f, return errors.New("invalid call flags")
RequiredFlags: smartcontract.ReadOnly,
} }
v.Estack().PushVal(stackitem.NewInterop(new(int)))
case 0x55555555: case 0x55555555:
return &InteropFuncPrice{ v.Estack().PushVal(stackitem.NewInterop(new(int)))
Func: f,
}
case 0xADDEADDE: case 0xADDEADDE:
return &InteropFuncPrice{ v.throw(stackitem.Make("error"))
Func: func(v *VM) error { default:
v.throw(stackitem.Make("error")) return errors.New("syscall not found")
return nil
},
}
} }
return nil return nil
} }
@ -163,7 +155,7 @@ func testFile(t *testing.T, filename string) {
prog := []byte(test.Script) prog := []byte(test.Script)
vm := load(prog) vm := load(prog)
vm.state = BreakState vm.state = BreakState
vm.RegisterInteropGetter(getTestingInterop) vm.SyscallHandler = testSyscallHandler
for i := range test.Steps { for i := range test.Steps {
execStep(t, vm, test.Steps[i]) execStep(t, vm, test.Steps[i])

View file

@ -58,13 +58,13 @@ const (
maxSHLArg = stackitem.MaxBigIntegerSizeBits maxSHLArg = stackitem.MaxBigIntegerSizeBits
) )
// SyscallHandler is a type for syscall handler.
type SyscallHandler = func(*VM, uint32) error
// VM represents the virtual machine. // VM represents the virtual machine.
type VM struct { type VM struct {
state State state State
// callbacks to get interops.
getInterop []InteropGetterFunc
// callback to get interop price // callback to get interop price
getPrice func(*VM, opcode.Opcode, []byte) int64 getPrice func(*VM, opcode.Opcode, []byte) int64
@ -78,6 +78,9 @@ type VM struct {
gasConsumed int64 gasConsumed int64
GasLimit int64 GasLimit int64
// SyscallHandler handles SYSCALL opcode.
SyscallHandler func(v *VM, id uint32) error
trigger trigger.Type trigger trigger.Type
// Public keys cache. // Public keys cache.
@ -92,17 +95,16 @@ func New() *VM {
// NewWithTrigger returns a new VM for executions triggered by t. // NewWithTrigger returns a new VM for executions triggered by t.
func NewWithTrigger(t trigger.Type) *VM { func NewWithTrigger(t trigger.Type) *VM {
vm := &VM{ vm := &VM{
getInterop: make([]InteropGetterFunc, 0, 3), // 3 functions is typical for our default usage. state: HaltState,
state: HaltState, istack: NewStack("invocation"),
istack: NewStack("invocation"), refs: newRefCounter(),
refs: newRefCounter(), keys: make(map[string]*keys.PublicKey),
keys: make(map[string]*keys.PublicKey), trigger: t,
trigger: t,
SyscallHandler: defaultSyscallHandler,
} }
vm.estack = vm.newItemStack("evaluation") vm.estack = vm.newItemStack("evaluation")
vm.RegisterInteropGetter(getDefaultVMInterop)
return vm return vm
} }
@ -113,13 +115,6 @@ func (v *VM) newItemStack(n string) *Stack {
return s return s
} }
// RegisterInteropGetter registers the given InteropGetterFunc into VM. There
// can be many interop getters and they're probed in LIFO order wrt their
// registration time.
func (v *VM) RegisterInteropGetter(f InteropGetterFunc) {
v.getInterop = append(v.getInterop, f)
}
// SetPriceGetter registers the given PriceGetterFunc in v. // SetPriceGetter registers the given PriceGetterFunc in v.
// f accepts vm's Context, current instruction and instruction parameter. // f accepts vm's Context, current instruction and instruction parameter.
func (v *VM) SetPriceGetter(f func(*VM, opcode.Opcode, []byte) int64) { func (v *VM) SetPriceGetter(f func(*VM, opcode.Opcode, []byte) int64) {
@ -494,18 +489,6 @@ func GetInteropID(parameter []byte) uint32 {
return binary.LittleEndian.Uint32(parameter) return binary.LittleEndian.Uint32(parameter)
} }
// GetInteropByID returns interop function together with price.
// Registered callbacks are checked in LIFO order.
func (v *VM) GetInteropByID(id uint32) *InteropFuncPrice {
for i := len(v.getInterop) - 1; i >= 0; i-- {
if ifunc := v.getInterop[i](id); ifunc != nil {
return ifunc
}
}
return nil
}
// execute performs an instruction cycle in the VM. Acting on the instruction (opcode). // execute performs an instruction cycle in the VM. Acting on the instruction (opcode).
func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err error) { func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err error) {
// Instead of polluting the whole VM logic with error handling, we will recover // Instead of polluting the whole VM logic with error handling, we will recover
@ -1250,15 +1233,8 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro
case opcode.SYSCALL: case opcode.SYSCALL:
interopID := GetInteropID(parameter) interopID := GetInteropID(parameter)
ifunc := v.GetInteropByID(interopID) err := v.SyscallHandler(v, interopID)
if !v.Context().callFlag.Has(ifunc.RequiredFlags) { if err != nil {
panic(fmt.Sprintf("missing call flags: %05b vs %05b", v.Context().callFlag, ifunc.RequiredFlags))
}
if ifunc == nil {
panic(fmt.Sprintf("interop hook (%q/0x%x) not registered", parameter, interopID))
}
if err := ifunc.Func(v); err != nil {
panic(fmt.Sprintf("failed to invoke syscall: %s", err)) panic(fmt.Sprintf("failed to invoke syscall: %s", err))
} }

View file

@ -17,26 +17,25 @@ import (
"github.com/nspcc-dev/neo-go/pkg/vm/emit" "github.com/nspcc-dev/neo-go/pkg/vm/emit"
"github.com/nspcc-dev/neo-go/pkg/vm/opcode" "github.com/nspcc-dev/neo-go/pkg/vm/opcode"
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
"github.com/pkg/errors"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func fooInteropGetter(id uint32) *InteropFuncPrice { func fooInteropHandler(v *VM, id uint32) error {
if id == emit.InteropNameToID([]byte("foo")) { if id == emit.InteropNameToID([]byte("foo")) {
return &InteropFuncPrice{ if !v.AddGas(1) {
Func: func(evm *VM) error { return errors.New("invalid gas amount")
evm.Estack().PushVal(1)
return nil
},
Price: 1,
} }
v.Estack().PushVal(1)
return nil
} }
return nil return errors.New("syscall not found")
} }
func TestInteropHook(t *testing.T) { func TestInteropHook(t *testing.T) {
v := newTestVM() v := newTestVM()
v.RegisterInteropGetter(fooInteropGetter) v.SyscallHandler = fooInteropHandler
buf := io.NewBufBinWriter() buf := io.NewBufBinWriter()
emit.Syscall(buf.BinWriter, "foo") emit.Syscall(buf.BinWriter, "foo")
@ -47,13 +46,6 @@ func TestInteropHook(t *testing.T) {
assert.Equal(t, big.NewInt(1), v.estack.Pop().value.Value()) assert.Equal(t, big.NewInt(1), v.estack.Pop().value.Value())
} }
func TestRegisterInteropGetter(t *testing.T) {
v := newTestVM()
currRegistered := len(v.getInterop)
v.RegisterInteropGetter(fooInteropGetter)
assert.Equal(t, currRegistered+1, len(v.getInterop))
}
func TestVM_SetPriceGetter(t *testing.T) { func TestVM_SetPriceGetter(t *testing.T) {
v := newTestVM() v := newTestVM()
prog := []byte{ prog := []byte{
@ -819,7 +811,7 @@ func getTestCallFlagsFunc(syscall []byte, flags smartcontract.CallFlag, result i
return func(t *testing.T) { return func(t *testing.T) {
script := append([]byte{byte(opcode.SYSCALL)}, syscall...) script := append([]byte{byte(opcode.SYSCALL)}, syscall...)
v := newTestVM() v := newTestVM()
v.RegisterInteropGetter(getTestingInterop) v.SyscallHandler = testSyscallHandler
v.LoadScriptWithFlags(script, flags) v.LoadScriptWithFlags(script, flags)
if result == nil { if result == nil {
checkVMFailed(t, v) checkVMFailed(t, v)