From 51ae12e4fd2bfbaf92763abe727dc89e5e99b0c4 Mon Sep 17 00:00:00 2001 From: Evgenii Stratonikov Date: Tue, 28 Jul 2020 16:38:00 +0300 Subject: [PATCH] *: move syscall handling out of VM Remove interop-related structures from the `vm` package. Signed-off-by: Evgenii Stratonikov --- pkg/compiler/panic_test.go | 21 +++--- pkg/compiler/util_test.go | 3 +- pkg/compiler/vm_test.go | 17 +++-- pkg/consensus/payload.go | 7 +- pkg/core/blockchain.go | 10 +-- pkg/core/gas_price.go | 8 --- pkg/core/interop/context.go | 50 +++++++++++++ pkg/core/interop/crypto/ecdsa_test.go | 5 +- pkg/core/interop/crypto/hash_test.go | 9 ++- pkg/core/interop/crypto/interop.go | 47 ++++-------- pkg/core/interop/json/json_test.go | 27 +++---- pkg/core/interops.go | 44 +----------- pkg/smartcontract/context/context_test.go | 6 +- pkg/vm/interop.go | 88 +++++++++++------------ pkg/vm/json_test.go | 30 +++----- pkg/vm/vm.go | 54 ++++---------- pkg/vm/vm_test.go | 26 +++---- 17 files changed, 195 insertions(+), 257 deletions(-) diff --git a/pkg/compiler/panic_test.go b/pkg/compiler/panic_test.go index 0661f75c0..8504f72d4 100644 --- a/pkg/compiler/panic_test.go +++ b/pkg/compiler/panic_test.go @@ -1,6 +1,7 @@ package compiler_test import ( + "errors" "fmt" "math/big" "testing" @@ -20,7 +21,7 @@ func TestPanic(t *testing.T) { var logs []string src := getPanicSource(true, `"execution fault"`) v := vmAndCompile(t, src) - v.RegisterInteropGetter(logGetter(&logs)) + v.SyscallHandler = getLogHandler(&logs) require.Error(t, v.Run()) require.True(t, v.HasFailed()) @@ -32,7 +33,7 @@ func TestPanic(t *testing.T) { var logs []string src := getPanicSource(true, `nil`) v := vmAndCompile(t, src) - v.RegisterInteropGetter(logGetter(&logs)) + v.SyscallHandler = getLogHandler(&logs) require.Error(t, v.Run()) require.True(t, v.HasFailed()) @@ -54,19 +55,15 @@ func getPanicSource(need bool, message string) string { `, need, message) } -func logGetter(logs *[]string) vm.InteropGetterFunc { +func getLogHandler(logs *[]string) vm.SyscallHandler { logID := emit.InteropNameToID([]byte("System.Runtime.Log")) - return func(id uint32) *vm.InteropFuncPrice { + return func(v *vm.VM, id uint32) error { if id != logID { - return nil + return errors.New("syscall not found") } - return &vm.InteropFuncPrice{ - Func: func(v *vm.VM) error { - msg := string(v.Estack().Pop().Bytes()) - *logs = append(*logs, msg) - return nil - }, - } + msg := string(v.Estack().Pop().Bytes()) + *logs = append(*logs, msg) + return nil } } diff --git a/pkg/compiler/util_test.go b/pkg/compiler/util_test.go index 59a628642..996a466b8 100644 --- a/pkg/compiler/util_test.go +++ b/pkg/compiler/util_test.go @@ -23,7 +23,8 @@ func TestSHA256(t *testing.T) { ` v := vmAndCompile(t, src) ic := &interop.Context{Trigger: trigger.Verification} - v.RegisterInteropGetter(crypto.GetInterop(ic)) + crypto.Register(ic) + v.SyscallHandler = ic.SyscallHandler require.NoError(t, v.Run()) require.True(t, v.Estack().Len() >= 1) diff --git a/pkg/compiler/vm_test.go b/pkg/compiler/vm_test.go index f2af0d5b8..237ee8c12 100644 --- a/pkg/compiler/vm_test.go +++ b/pkg/compiler/vm_test.go @@ -1,6 +1,7 @@ package compiler_test import ( + "errors" "fmt" "strings" "testing" @@ -68,7 +69,8 @@ func vmAndCompileInterop(t *testing.T, src string) (*vm.VM, *storagePlugin) { vm := vm.New() storePlugin := newStoragePlugin() - vm.RegisterInteropGetter(storePlugin.getInterop) + vm.GasLimit = -1 + vm.SyscallHandler = storePlugin.syscallHandler b, di, err := compiler.CompileWithDebugInfo(strings.NewReader(src)) require.NoError(t, err) @@ -97,14 +99,14 @@ func invokeMethod(t *testing.T, method string, script []byte, v *vm.VM, di *comp type storagePlugin struct { mem map[string][]byte - interops map[uint32]vm.InteropFunc + interops map[uint32]func(v *vm.VM) error events []state.NotificationEvent } func newStoragePlugin() *storagePlugin { s := &storagePlugin{ mem: make(map[string][]byte), - interops: make(map[uint32]vm.InteropFunc), + interops: make(map[uint32]func(v *vm.VM) error), } s.interops[emit.InteropNameToID([]byte("System.Storage.Get"))] = s.Get s.interops[emit.InteropNameToID([]byte("System.Storage.Put"))] = s.Put @@ -114,12 +116,15 @@ func newStoragePlugin() *storagePlugin { } -func (s *storagePlugin) getInterop(id uint32) *vm.InteropFuncPrice { +func (s *storagePlugin) syscallHandler(v *vm.VM, id uint32) error { f := s.interops[id] if f != nil { - return &vm.InteropFuncPrice{Func: f, Price: 1} + if !v.AddGas(1) { + return errors.New("insufficient amount of gas") + } + return f(v) } - return nil + return errors.New("syscall not found") } func (s *storagePlugin) Notify(v *vm.VM) error { diff --git a/pkg/consensus/payload.go b/pkg/consensus/payload.go index 5bfde2a75..c61ac2799 100644 --- a/pkg/consensus/payload.go +++ b/pkg/consensus/payload.go @@ -11,8 +11,8 @@ import ( "github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/crypto/hash" "github.com/nspcc-dev/neo-go/pkg/io" + "github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger" "github.com/nspcc-dev/neo-go/pkg/util" - "github.com/nspcc-dev/neo-go/pkg/vm" "github.com/nspcc-dev/neo-go/pkg/vm/emit" "github.com/pkg/errors" ) @@ -222,9 +222,10 @@ func (p *Payload) Verify(scriptHash util.Uint160) bool { return false } - v := vm.New() + ic := &interop.Context{Trigger: trigger.Verification, Container: p} + crypto.Register(ic) + v := ic.SpawnVM() v.GasLimit = payloadGasLimit - v.RegisterInteropGetter(crypto.GetInterop(&interop.Context{Container: p})) v.LoadScript(verification) v.LoadScript(p.Witness.InvocationScript) diff --git a/pkg/core/blockchain.go b/pkg/core/blockchain.go index 96a06ea54..7ac605bed 100644 --- a/pkg/core/blockchain.go +++ b/pkg/core/blockchain.go @@ -566,8 +566,7 @@ func (bc *Blockchain) storeBlock(block *block.Block) error { if block.Index > 0 { systemInterop := bc.newInteropContext(trigger.System, cache, block, nil) - v := SpawnVM(systemInterop) - v.GasLimit = -1 + v := systemInterop.SpawnVM() v.LoadScriptWithFlags(bc.contracts.GetPersistScript(), smartcontract.AllowModifyStates|smartcontract.AllowCall) v.SetPriceGetter(getPrice) if err := v.Run(); err != nil { @@ -599,7 +598,7 @@ func (bc *Blockchain) storeBlock(block *block.Block) error { } systemInterop := bc.newInteropContext(trigger.Application, cache, block, tx) - v := SpawnVM(systemInterop) + v := systemInterop.SpawnVM() v.LoadScriptWithFlags(tx.Script, smartcontract.All) v.SetPriceGetter(getPrice) v.GasLimit = tx.SystemFee @@ -1277,7 +1276,7 @@ func (bc *Blockchain) GetScriptHashesForVerifying(t *transaction.Transaction) ([ // GetTestVM returns a VM and a Store setup for a test run of some sort of code. func (bc *Blockchain) GetTestVM(tx *transaction.Transaction) *vm.VM { systemInterop := bc.newInteropContext(trigger.Application, bc.dao, nil, tx) - vm := SpawnVM(systemInterop) + vm := systemInterop.SpawnVM() vm.SetPriceGetter(getPrice) return vm } @@ -1310,7 +1309,7 @@ func (bc *Blockchain) verifyHashAgainstScript(hash util.Uint160, witness *transa gas = gasPolicy } - vm := SpawnVM(interopCtx) + vm := interopCtx.SpawnVM() vm.SetPriceGetter(getPrice) vm.GasLimit = gas vm.LoadScriptWithFlags(verification, smartcontract.ReadOnly) @@ -1413,6 +1412,7 @@ func (bc *Blockchain) secondsPerBlock() int { func (bc *Blockchain) newInteropContext(trigger trigger.Type, d dao.DAO, block *block.Block, tx *transaction.Transaction) *interop.Context { ic := interop.NewContext(trigger, bc, d, bc.contracts.Contracts, block, tx, bc.log) + ic.Functions = [][]interop.Function{systemInterops, neoInterops} switch { case tx != nil: ic.Container = tx diff --git a/pkg/core/gas_price.go b/pkg/core/gas_price.go index 7741206d5..61657d379 100644 --- a/pkg/core/gas_price.go +++ b/pkg/core/gas_price.go @@ -9,14 +9,6 @@ import ( const StoragePrice = 100000 // getPrice returns a price for executing op with the provided parameter. -// Some SYSCALLs have variable price depending on their arguments. func getPrice(v *vm.VM, op opcode.Opcode, parameter []byte) int64 { - if op == opcode.SYSCALL { - interopID := vm.GetInteropID(parameter) - ifunc := v.GetInteropByID(interopID) - if ifunc != nil && ifunc.Price > 0 { - return ifunc.Price - } - } return opcodePrice(op) } diff --git a/pkg/core/interop/context.go b/pkg/core/interop/context.go index 559c92291..fc385cc2a 100644 --- a/pkg/core/interop/context.go +++ b/pkg/core/interop/context.go @@ -1,6 +1,10 @@ package interop import ( + "errors" + "fmt" + "sort" + "github.com/nspcc-dev/neo-go/pkg/core/block" "github.com/nspcc-dev/neo-go/pkg/core/blockchainer" "github.com/nspcc-dev/neo-go/pkg/core/dao" @@ -32,6 +36,7 @@ type Context struct { Log *zap.Logger Invocations map[util.Uint160]int ScriptGetter vm.ScriptHashGetter + Functions [][]Function } // NewContext returns new interop context. @@ -48,6 +53,8 @@ func NewContext(trigger trigger.Type, bc blockchainer.Blockchainer, d dao.DAO, n Notifications: nes, Log: log, Invocations: make(map[util.Uint160]int), + // Functions is a slice of slices of interops sorted by ID. + Functions: [][]Function{}, } } @@ -124,3 +131,46 @@ func (c *ContractMD) AddEvent(name string, ps ...manifest.Parameter) { Parameters: ps, }) } + +// Sort sorts interop functions by id. +func Sort(fs []Function) { + sort.Slice(fs, func(i, j int) bool { return fs[i].ID < fs[j].ID }) +} + +// GetFunction returns metadata for interop with the specified id. +func (ic *Context) GetFunction(id uint32) *Function { + for _, slice := range ic.Functions { + n := sort.Search(len(slice), func(i int) bool { + return slice[i].ID >= id + }) + if n < len(slice) && slice[n].ID == id { + return &slice[n] + } + } + return nil +} + +// SyscallHandler handles syscall with id. +func (ic *Context) SyscallHandler(v *vm.VM, id uint32) error { + f := ic.GetFunction(id) + if f == nil { + return errors.New("syscall not found") + } + cf := v.Context().GetCallFlags() + if !cf.Has(f.RequiredFlags) { + return fmt.Errorf("missing call flags: %05b vs %05b", cf, f.RequiredFlags) + } + if !v.AddGas(f.Price) { + return errors.New("insufficient amount of gas") + } + return f.Func(ic, v) +} + +// SpawnVM spawns new VM with the specified gas limit. +func (ic *Context) SpawnVM() *vm.VM { + v := vm.NewWithTrigger(ic.Trigger) + v.GasLimit = -1 + v.SyscallHandler = ic.SyscallHandler + ic.ScriptGetter = v + return v +} diff --git a/pkg/core/interop/crypto/ecdsa_test.go b/pkg/core/interop/crypto/ecdsa_test.go index 4dd735e87..c2546c81c 100644 --- a/pkg/core/interop/crypto/ecdsa_test.go +++ b/pkg/core/interop/crypto/ecdsa_test.go @@ -59,10 +59,9 @@ func initCHECKMULTISIGVM(t *testing.T, n int, ik, is []int) *vm.VM { buf[0] = byte(opcode.SYSCALL) binary.LittleEndian.PutUint32(buf[1:], ecdsaSecp256r1CheckMultisigID) - v := vm.New() - v.GasLimit = -1 ic := &interop.Context{Trigger: trigger.Verification} - v.RegisterInteropGetter(GetInterop(ic)) + Register(ic) + v := ic.SpawnVM() v.LoadScript(buf) msg := []byte("NEO - An Open Network For Smart Economy") diff --git a/pkg/core/interop/crypto/hash_test.go b/pkg/core/interop/crypto/hash_test.go index 566bcfb6e..9f4bdc22c 100644 --- a/pkg/core/interop/crypto/hash_test.go +++ b/pkg/core/interop/crypto/hash_test.go @@ -7,7 +7,6 @@ import ( "github.com/nspcc-dev/neo-go/pkg/core/interop" "github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger" - "github.com/nspcc-dev/neo-go/pkg/vm" "github.com/nspcc-dev/neo-go/pkg/vm/emit" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -20,9 +19,9 @@ func TestSHA256(t *testing.T) { emit.Bytes(buf.BinWriter, []byte{1, 0}) emit.Syscall(buf.BinWriter, "Neo.Crypto.SHA256") prog := buf.Bytes() - v := vm.New() ic := &interop.Context{Trigger: trigger.Verification} - v.RegisterInteropGetter(GetInterop(ic)) + Register(ic) + v := ic.SpawnVM() v.Load(prog) require.NoError(t, v.Run()) assert.Equal(t, 1, v.Estack().Len()) @@ -36,9 +35,9 @@ func TestRIPEMD160(t *testing.T) { emit.Bytes(buf.BinWriter, []byte{1, 0}) emit.Syscall(buf.BinWriter, "Neo.Crypto.RIPEMD160") prog := buf.Bytes() - v := vm.New() ic := &interop.Context{Trigger: trigger.Verification} - v.RegisterInteropGetter(GetInterop(ic)) + Register(ic) + v := ic.SpawnVM() v.Load(prog) require.NoError(t, v.Run()) assert.Equal(t, 1, v.Estack().Len()) diff --git a/pkg/core/interop/crypto/interop.go b/pkg/core/interop/crypto/interop.go index aff3410d4..b680b69d8 100644 --- a/pkg/core/interop/crypto/interop.go +++ b/pkg/core/interop/crypto/interop.go @@ -2,7 +2,6 @@ package crypto import ( "github.com/nspcc-dev/neo-go/pkg/core/interop" - "github.com/nspcc-dev/neo-go/pkg/vm" "github.com/nspcc-dev/neo-go/pkg/vm/emit" ) @@ -13,36 +12,18 @@ var ( ripemd160ID = emit.InteropNameToID([]byte("Neo.Crypto.RIPEMD160")) ) -// GetInterop returns interop getter for crypto-related stuff. -func GetInterop(ic *interop.Context) func(uint32) *vm.InteropFuncPrice { - return func(id uint32) *vm.InteropFuncPrice { - switch id { - case ecdsaSecp256r1VerifyID: - return &vm.InteropFuncPrice{ - Func: func(v *vm.VM) error { - return ECDSASecp256r1Verify(ic, v) - }, - } - case ecdsaSecp256r1CheckMultisigID: - return &vm.InteropFuncPrice{ - Func: func(v *vm.VM) error { - return ECDSASecp256r1CheckMultisig(ic, v) - }, - } - case sha256ID: - return &vm.InteropFuncPrice{ - Func: func(v *vm.VM) error { - return Sha256(ic, v) - }, - } - case ripemd160ID: - return &vm.InteropFuncPrice{ - Func: func(v *vm.VM) error { - return RipeMD160(ic, v) - }, - } - default: - return nil - } - } +var cryptoInterops = []interop.Function{ + {ID: ecdsaSecp256r1VerifyID, Func: ECDSASecp256r1Verify}, + {ID: ecdsaSecp256r1CheckMultisigID, Func: ECDSASecp256r1CheckMultisig}, + {ID: sha256ID, Func: Sha256}, + {ID: ripemd160ID, Func: RipeMD160}, +} + +func init() { + interop.Sort(cryptoInterops) +} + +// Register adds crypto interops to ic. +func Register(ic *interop.Context) { + ic.Functions = append(ic.Functions, cryptoInterops) } diff --git a/pkg/core/interop/json/json_test.go b/pkg/core/interop/json/json_test.go index b1d4bb34f..12a86edcc 100644 --- a/pkg/core/interop/json/json_test.go +++ b/pkg/core/interop/json/json_test.go @@ -4,7 +4,7 @@ import ( "encoding/binary" "testing" - "github.com/nspcc-dev/neo-go/pkg/vm" + "github.com/nspcc-dev/neo-go/pkg/core/interop" "github.com/nspcc-dev/neo-go/pkg/vm/emit" "github.com/nspcc-dev/neo-go/pkg/vm/opcode" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" @@ -16,19 +16,13 @@ var ( deserializeID = emit.InteropNameToID([]byte("System.Json.Deserialize")) ) -func getInterop(id uint32) *vm.InteropFuncPrice { - switch id { - case serializeID: - return &vm.InteropFuncPrice{ - Func: func(v *vm.VM) error { return Serialize(nil, v) }, - } - case deserializeID: - return &vm.InteropFuncPrice{ - Func: func(v *vm.VM) error { return Deserialize(nil, v) }, - } - default: - return nil - } +var jsonInterops = []interop.Function{ + {ID: serializeID, Func: Serialize}, + {ID: deserializeID, Func: Deserialize}, +} + +func init() { + interop.Sort(jsonInterops) } func getTestFunc(id uint32, arg interface{}, result interface{}) func(t *testing.T) { @@ -37,8 +31,9 @@ func getTestFunc(id uint32, arg interface{}, result interface{}) func(t *testing binary.LittleEndian.PutUint32(prog[1:], id) return func(t *testing.T) { - v := vm.New() - v.RegisterInteropGetter(getInterop) + ic := &interop.Context{} + ic.Functions = append(ic.Functions, jsonInterops) + v := ic.SpawnVM() v.LoadScript(prog) v.Estack().PushVal(arg) if result == nil { diff --git a/pkg/core/interops.go b/pkg/core/interops.go index 2ee8d780f..778f4974f 100644 --- a/pkg/core/interops.go +++ b/pkg/core/interops.go @@ -8,8 +8,6 @@ package core */ import ( - "sort" - "github.com/nspcc-dev/neo-go/pkg/core/interop" "github.com/nspcc-dev/neo-go/pkg/core/interop/crypto" "github.com/nspcc-dev/neo-go/pkg/core/interop/enumerator" @@ -25,45 +23,11 @@ import ( // SpawnVM returns a VM with script getter and interop functions set // up for current blockchain. func SpawnVM(ic *interop.Context) *vm.VM { - vm := vm.NewWithTrigger(ic.Trigger) - vm.RegisterInteropGetter(getSystemInterop(ic)) - vm.RegisterInteropGetter(getNeoInterop(ic)) - ic.ScriptGetter = vm + vm := ic.SpawnVM() + ic.Functions = [][]interop.Function{systemInterops, neoInterops} return vm } -// getSystemInterop returns matching interop function from the System namespace -// for a given id in the current context. -func getSystemInterop(ic *interop.Context) vm.InteropGetterFunc { - return getInteropFromSlice(ic, systemInterops) -} - -// getNeoInterop returns matching interop function from the Neo and AntShares -// namespaces for a given id in the current context. -func getNeoInterop(ic *interop.Context) vm.InteropGetterFunc { - return getInteropFromSlice(ic, neoInterops) -} - -// getInteropFromSlice returns matching interop function from the given slice of -// interop functions in the current context. -func getInteropFromSlice(ic *interop.Context, slice []interop.Function) func(uint32) *vm.InteropFuncPrice { - return func(id uint32) *vm.InteropFuncPrice { - n := sort.Search(len(slice), func(i int) bool { - return slice[i].ID >= id - }) - if n < len(slice) && slice[n].ID == id { - return &vm.InteropFuncPrice{ - Func: func(v *vm.VM) error { - return slice[n].Func(ic, v) - }, - Price: slice[n].Price, - RequiredFlags: slice[n].RequiredFlags, - } - } - return nil - } -} - // All lists are sorted, keep 'em this way, please. var systemInterops = []interop.Function{ {Name: "System.Binary.Base64Decode", Func: runtimeDecode, Price: 100000}, @@ -136,9 +100,7 @@ func initIDinInteropsSlice(iops []interop.Function) { for i := range iops { iops[i].ID = emit.InteropNameToID([]byte(iops[i].Name)) } - sort.Slice(iops, func(i, j int) bool { - return iops[i].ID < iops[j].ID - }) + interop.Sort(iops) } // init initializes IDs in the global interop slices. diff --git a/pkg/smartcontract/context/context_test.go b/pkg/smartcontract/context/context_test.go index 5ae5e836d..309135e79 100644 --- a/pkg/smartcontract/context/context_test.go +++ b/pkg/smartcontract/context/context_test.go @@ -110,9 +110,9 @@ func TestParameterContext_AddSignatureMultisig(t *testing.T) { } func newTestVM(w *transaction.Witness, tx *transaction.Transaction) *vm.VM { - v := vm.New() - v.GasLimit = -1 - v.RegisterInteropGetter(crypto.GetInterop(&interop.Context{Container: tx})) + ic := &interop.Context{Container: tx} + crypto.Register(ic) + v := ic.SpawnVM() v.LoadScript(w.VerificationScript) v.LoadScript(w.InvocationScript) return v diff --git a/pkg/vm/interop.go b/pkg/vm/interop.go index 218d93332..bf04a2e8e 100644 --- a/pkg/vm/interop.go +++ b/pkg/vm/interop.go @@ -10,63 +10,59 @@ import ( "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" ) -// InteropFunc allows to hook into the VM. -type InteropFunc func(vm *VM) error - -// InteropFuncPrice represents an interop function with a price. -type InteropFuncPrice struct { - Func InteropFunc +// interopIDFuncPrice adds an ID to the InteropFuncPrice. +type interopIDFuncPrice struct { + ID uint32 + Func func(vm *VM) error Price int64 RequiredFlags smartcontract.CallFlag } -// interopIDFuncPrice adds an ID to the InteropFuncPrice. -type interopIDFuncPrice struct { - ID uint32 - InteropFuncPrice -} - -// InteropGetterFunc is a function that returns an interop function-price -// structure by the given interop ID. -type InteropGetterFunc func(uint32) *InteropFuncPrice - var defaultVMInterops = []interopIDFuncPrice{ - {emit.InteropNameToID([]byte("System.Binary.Deserialize")), - InteropFuncPrice{Func: RuntimeDeserialize, Price: 500000}}, - {emit.InteropNameToID([]byte("System.Binary.Serialize")), - InteropFuncPrice{Func: RuntimeSerialize, Price: 100000}}, - {emit.InteropNameToID([]byte("System.Runtime.Log")), - InteropFuncPrice{Func: runtimeLog, Price: 1000000, RequiredFlags: smartcontract.AllowNotify}}, - {emit.InteropNameToID([]byte("System.Runtime.Notify")), - InteropFuncPrice{Func: runtimeNotify, Price: 1000000, RequiredFlags: smartcontract.AllowNotify}}, - {emit.InteropNameToID([]byte("System.Enumerator.Create")), - InteropFuncPrice{Func: EnumeratorCreate, Price: 400}}, - {emit.InteropNameToID([]byte("System.Enumerator.Next")), - InteropFuncPrice{Func: EnumeratorNext, Price: 1000000}}, - {emit.InteropNameToID([]byte("System.Enumerator.Concat")), - InteropFuncPrice{Func: EnumeratorConcat, Price: 400}}, - {emit.InteropNameToID([]byte("System.Enumerator.Value")), - InteropFuncPrice{Func: EnumeratorValue, Price: 400}}, - {emit.InteropNameToID([]byte("System.Iterator.Create")), - InteropFuncPrice{Func: IteratorCreate, Price: 400}}, - {emit.InteropNameToID([]byte("System.Iterator.Concat")), - InteropFuncPrice{Func: IteratorConcat, Price: 400}}, - {emit.InteropNameToID([]byte("System.Iterator.Key")), - InteropFuncPrice{Func: IteratorKey, Price: 400}}, - {emit.InteropNameToID([]byte("System.Iterator.Keys")), - InteropFuncPrice{Func: IteratorKeys, Price: 400}}, - {emit.InteropNameToID([]byte("System.Iterator.Values")), - InteropFuncPrice{Func: IteratorValues, Price: 400}}, + {ID: emit.InteropNameToID([]byte("System.Binary.Deserialize")), + Func: RuntimeDeserialize, Price: 500000}, + {ID: emit.InteropNameToID([]byte("System.Binary.Serialize")), + Func: RuntimeSerialize, Price: 100000}, + {ID: emit.InteropNameToID([]byte("System.Runtime.Log")), + Func: runtimeLog, Price: 1000000, RequiredFlags: smartcontract.AllowNotify}, + {ID: emit.InteropNameToID([]byte("System.Runtime.Notify")), + Func: runtimeNotify, Price: 1000000, RequiredFlags: smartcontract.AllowNotify}, + {ID: emit.InteropNameToID([]byte("System.Enumerator.Create")), + Func: EnumeratorCreate, Price: 400}, + {ID: emit.InteropNameToID([]byte("System.Enumerator.Next")), + Func: EnumeratorNext, Price: 1000000}, + {ID: emit.InteropNameToID([]byte("System.Enumerator.Concat")), + Func: EnumeratorConcat, Price: 400}, + {ID: emit.InteropNameToID([]byte("System.Enumerator.Value")), + Func: EnumeratorValue, Price: 400}, + {ID: emit.InteropNameToID([]byte("System.Iterator.Create")), + Func: IteratorCreate, Price: 400}, + {ID: emit.InteropNameToID([]byte("System.Iterator.Concat")), + Func: IteratorConcat, Price: 400}, + {ID: emit.InteropNameToID([]byte("System.Iterator.Key")), + Func: IteratorKey, Price: 400}, + {ID: emit.InteropNameToID([]byte("System.Iterator.Keys")), + Func: IteratorKeys, Price: 400}, + {ID: emit.InteropNameToID([]byte("System.Iterator.Values")), + Func: IteratorValues, Price: 400}, } -func getDefaultVMInterop(id uint32) *InteropFuncPrice { +func init() { + sort.Slice(defaultVMInterops, func(i, j int) bool { return defaultVMInterops[i].ID < defaultVMInterops[j].ID }) +} + +func defaultSyscallHandler(v *VM, id uint32) error { n := sort.Search(len(defaultVMInterops), func(i int) bool { return defaultVMInterops[i].ID >= id }) - if n < len(defaultVMInterops) && defaultVMInterops[n].ID == id { - return &defaultVMInterops[n].InteropFuncPrice + if n >= len(defaultVMInterops) || defaultVMInterops[n].ID != id { + return errors.New("syscall not found") } - return nil + d := defaultVMInterops[n] + if !v.Context().callFlag.Has(d.RequiredFlags) { + return fmt.Errorf("missing call flags: %05b vs %05b", v.Context().callFlag, d.RequiredFlags) + } + return d.Func(v) } // runtimeLog handles the syscall "System.Runtime.Log" for printing and logging stuff. diff --git a/pkg/vm/json_test.go b/pkg/vm/json_test.go index 78633cc99..7a86f3a70 100644 --- a/pkg/vm/json_test.go +++ b/pkg/vm/json_test.go @@ -19,6 +19,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/smartcontract" "github.com/nspcc-dev/neo-go/pkg/vm/opcode" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" + "github.com/pkg/errors" "github.com/stretchr/testify/require" ) @@ -110,30 +111,21 @@ func TestUT(t *testing.T) { require.Equal(t, true, testsRan, "neo-vm tests should be available (check submodules)") } -func getTestingInterop(id uint32) *InteropFuncPrice { - f := func(v *VM) error { - v.estack.PushVal(stackitem.NewInterop(new(int))) - return nil - } +func testSyscallHandler(v *VM, id uint32) error { switch id { case 0x77777777: - return &InteropFuncPrice{Func: f} + v.Estack().PushVal(stackitem.NewInterop(new(int))) case 0x66666666: - return &InteropFuncPrice{ - Func: f, - RequiredFlags: smartcontract.ReadOnly, + if !v.Context().callFlag.Has(smartcontract.ReadOnly) { + return errors.New("invalid call flags") } + v.Estack().PushVal(stackitem.NewInterop(new(int))) case 0x55555555: - return &InteropFuncPrice{ - Func: f, - } + v.Estack().PushVal(stackitem.NewInterop(new(int))) case 0xADDEADDE: - return &InteropFuncPrice{ - Func: func(v *VM) error { - v.throw(stackitem.Make("error")) - return nil - }, - } + v.throw(stackitem.Make("error")) + default: + return errors.New("syscall not found") } return nil } @@ -163,7 +155,7 @@ func testFile(t *testing.T, filename string) { prog := []byte(test.Script) vm := load(prog) vm.state = BreakState - vm.RegisterInteropGetter(getTestingInterop) + vm.SyscallHandler = testSyscallHandler for i := range test.Steps { execStep(t, vm, test.Steps[i]) diff --git a/pkg/vm/vm.go b/pkg/vm/vm.go index 0ad2a65f6..dcef4e0e5 100644 --- a/pkg/vm/vm.go +++ b/pkg/vm/vm.go @@ -58,13 +58,13 @@ const ( maxSHLArg = stackitem.MaxBigIntegerSizeBits ) +// SyscallHandler is a type for syscall handler. +type SyscallHandler = func(*VM, uint32) error + // VM represents the virtual machine. type VM struct { state State - // callbacks to get interops. - getInterop []InteropGetterFunc - // callback to get interop price getPrice func(*VM, opcode.Opcode, []byte) int64 @@ -78,6 +78,9 @@ type VM struct { gasConsumed int64 GasLimit int64 + // SyscallHandler handles SYSCALL opcode. + SyscallHandler func(v *VM, id uint32) error + trigger trigger.Type // Public keys cache. @@ -92,17 +95,16 @@ func New() *VM { // NewWithTrigger returns a new VM for executions triggered by t. func NewWithTrigger(t trigger.Type) *VM { vm := &VM{ - getInterop: make([]InteropGetterFunc, 0, 3), // 3 functions is typical for our default usage. - state: HaltState, - istack: NewStack("invocation"), - refs: newRefCounter(), - keys: make(map[string]*keys.PublicKey), - trigger: t, + state: HaltState, + istack: NewStack("invocation"), + refs: newRefCounter(), + keys: make(map[string]*keys.PublicKey), + trigger: t, + + SyscallHandler: defaultSyscallHandler, } vm.estack = vm.newItemStack("evaluation") - - vm.RegisterInteropGetter(getDefaultVMInterop) return vm } @@ -113,13 +115,6 @@ func (v *VM) newItemStack(n string) *Stack { return s } -// RegisterInteropGetter registers the given InteropGetterFunc into VM. There -// can be many interop getters and they're probed in LIFO order wrt their -// registration time. -func (v *VM) RegisterInteropGetter(f InteropGetterFunc) { - v.getInterop = append(v.getInterop, f) -} - // SetPriceGetter registers the given PriceGetterFunc in v. // f accepts vm's Context, current instruction and instruction parameter. func (v *VM) SetPriceGetter(f func(*VM, opcode.Opcode, []byte) int64) { @@ -494,18 +489,6 @@ func GetInteropID(parameter []byte) uint32 { return binary.LittleEndian.Uint32(parameter) } -// GetInteropByID returns interop function together with price. -// Registered callbacks are checked in LIFO order. -func (v *VM) GetInteropByID(id uint32) *InteropFuncPrice { - for i := len(v.getInterop) - 1; i >= 0; i-- { - if ifunc := v.getInterop[i](id); ifunc != nil { - return ifunc - } - } - - return nil -} - // execute performs an instruction cycle in the VM. Acting on the instruction (opcode). func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err error) { // Instead of polluting the whole VM logic with error handling, we will recover @@ -1250,15 +1233,8 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro case opcode.SYSCALL: interopID := GetInteropID(parameter) - ifunc := v.GetInteropByID(interopID) - if !v.Context().callFlag.Has(ifunc.RequiredFlags) { - panic(fmt.Sprintf("missing call flags: %05b vs %05b", v.Context().callFlag, ifunc.RequiredFlags)) - } - - if ifunc == nil { - panic(fmt.Sprintf("interop hook (%q/0x%x) not registered", parameter, interopID)) - } - if err := ifunc.Func(v); err != nil { + err := v.SyscallHandler(v, interopID) + if err != nil { panic(fmt.Sprintf("failed to invoke syscall: %s", err)) } diff --git a/pkg/vm/vm_test.go b/pkg/vm/vm_test.go index f4a0a94ec..2c872cd94 100644 --- a/pkg/vm/vm_test.go +++ b/pkg/vm/vm_test.go @@ -17,26 +17,25 @@ import ( "github.com/nspcc-dev/neo-go/pkg/vm/emit" "github.com/nspcc-dev/neo-go/pkg/vm/opcode" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" + "github.com/pkg/errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func fooInteropGetter(id uint32) *InteropFuncPrice { +func fooInteropHandler(v *VM, id uint32) error { if id == emit.InteropNameToID([]byte("foo")) { - return &InteropFuncPrice{ - Func: func(evm *VM) error { - evm.Estack().PushVal(1) - return nil - }, - Price: 1, + if !v.AddGas(1) { + return errors.New("invalid gas amount") } + v.Estack().PushVal(1) + return nil } - return nil + return errors.New("syscall not found") } func TestInteropHook(t *testing.T) { v := newTestVM() - v.RegisterInteropGetter(fooInteropGetter) + v.SyscallHandler = fooInteropHandler buf := io.NewBufBinWriter() emit.Syscall(buf.BinWriter, "foo") @@ -47,13 +46,6 @@ func TestInteropHook(t *testing.T) { assert.Equal(t, big.NewInt(1), v.estack.Pop().value.Value()) } -func TestRegisterInteropGetter(t *testing.T) { - v := newTestVM() - currRegistered := len(v.getInterop) - v.RegisterInteropGetter(fooInteropGetter) - assert.Equal(t, currRegistered+1, len(v.getInterop)) -} - func TestVM_SetPriceGetter(t *testing.T) { v := newTestVM() prog := []byte{ @@ -819,7 +811,7 @@ func getTestCallFlagsFunc(syscall []byte, flags smartcontract.CallFlag, result i return func(t *testing.T) { script := append([]byte{byte(opcode.SYSCALL)}, syscall...) v := newTestVM() - v.RegisterInteropGetter(getTestingInterop) + v.SyscallHandler = testSyscallHandler v.LoadScriptWithFlags(script, flags) if result == nil { checkVMFailed(t, v)