diff --git a/pkg/compiler/panic_test.go b/pkg/compiler/panic_test.go index 0661f75c0..8504f72d4 100644 --- a/pkg/compiler/panic_test.go +++ b/pkg/compiler/panic_test.go @@ -1,6 +1,7 @@ package compiler_test import ( + "errors" "fmt" "math/big" "testing" @@ -20,7 +21,7 @@ func TestPanic(t *testing.T) { var logs []string src := getPanicSource(true, `"execution fault"`) v := vmAndCompile(t, src) - v.RegisterInteropGetter(logGetter(&logs)) + v.SyscallHandler = getLogHandler(&logs) require.Error(t, v.Run()) require.True(t, v.HasFailed()) @@ -32,7 +33,7 @@ func TestPanic(t *testing.T) { var logs []string src := getPanicSource(true, `nil`) v := vmAndCompile(t, src) - v.RegisterInteropGetter(logGetter(&logs)) + v.SyscallHandler = getLogHandler(&logs) require.Error(t, v.Run()) require.True(t, v.HasFailed()) @@ -54,19 +55,15 @@ func getPanicSource(need bool, message string) string { `, need, message) } -func logGetter(logs *[]string) vm.InteropGetterFunc { +func getLogHandler(logs *[]string) vm.SyscallHandler { logID := emit.InteropNameToID([]byte("System.Runtime.Log")) - return func(id uint32) *vm.InteropFuncPrice { + return func(v *vm.VM, id uint32) error { if id != logID { - return nil + return errors.New("syscall not found") } - return &vm.InteropFuncPrice{ - Func: func(v *vm.VM) error { - msg := string(v.Estack().Pop().Bytes()) - *logs = append(*logs, msg) - return nil - }, - } + msg := string(v.Estack().Pop().Bytes()) + *logs = append(*logs, msg) + return nil } } diff --git a/pkg/compiler/util_test.go b/pkg/compiler/util_test.go index 59a628642..996a466b8 100644 --- a/pkg/compiler/util_test.go +++ b/pkg/compiler/util_test.go @@ -23,7 +23,8 @@ func TestSHA256(t *testing.T) { ` v := vmAndCompile(t, src) ic := &interop.Context{Trigger: trigger.Verification} - v.RegisterInteropGetter(crypto.GetInterop(ic)) + crypto.Register(ic) + v.SyscallHandler = ic.SyscallHandler require.NoError(t, v.Run()) require.True(t, v.Estack().Len() >= 1) diff --git a/pkg/compiler/vm_test.go b/pkg/compiler/vm_test.go index f2af0d5b8..237ee8c12 100644 --- a/pkg/compiler/vm_test.go +++ b/pkg/compiler/vm_test.go @@ -1,6 +1,7 @@ package compiler_test import ( + "errors" "fmt" "strings" "testing" @@ -68,7 +69,8 @@ func vmAndCompileInterop(t *testing.T, src string) (*vm.VM, *storagePlugin) { vm := vm.New() storePlugin := newStoragePlugin() - vm.RegisterInteropGetter(storePlugin.getInterop) + vm.GasLimit = -1 + vm.SyscallHandler = storePlugin.syscallHandler b, di, err := compiler.CompileWithDebugInfo(strings.NewReader(src)) require.NoError(t, err) @@ -97,14 +99,14 @@ func invokeMethod(t *testing.T, method string, script []byte, v *vm.VM, di *comp type storagePlugin struct { mem map[string][]byte - interops map[uint32]vm.InteropFunc + interops map[uint32]func(v *vm.VM) error events []state.NotificationEvent } func newStoragePlugin() *storagePlugin { s := &storagePlugin{ mem: make(map[string][]byte), - interops: make(map[uint32]vm.InteropFunc), + interops: make(map[uint32]func(v *vm.VM) error), } s.interops[emit.InteropNameToID([]byte("System.Storage.Get"))] = s.Get s.interops[emit.InteropNameToID([]byte("System.Storage.Put"))] = s.Put @@ -114,12 +116,15 @@ func newStoragePlugin() *storagePlugin { } -func (s *storagePlugin) getInterop(id uint32) *vm.InteropFuncPrice { +func (s *storagePlugin) syscallHandler(v *vm.VM, id uint32) error { f := s.interops[id] if f != nil { - return &vm.InteropFuncPrice{Func: f, Price: 1} + if !v.AddGas(1) { + return errors.New("insufficient amount of gas") + } + return f(v) } - return nil + return errors.New("syscall not found") } func (s *storagePlugin) Notify(v *vm.VM) error { diff --git a/pkg/consensus/payload.go b/pkg/consensus/payload.go index 5bfde2a75..c61ac2799 100644 --- a/pkg/consensus/payload.go +++ b/pkg/consensus/payload.go @@ -11,8 +11,8 @@ import ( "github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/crypto/hash" "github.com/nspcc-dev/neo-go/pkg/io" + "github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger" "github.com/nspcc-dev/neo-go/pkg/util" - "github.com/nspcc-dev/neo-go/pkg/vm" "github.com/nspcc-dev/neo-go/pkg/vm/emit" "github.com/pkg/errors" ) @@ -222,9 +222,10 @@ func (p *Payload) Verify(scriptHash util.Uint160) bool { return false } - v := vm.New() + ic := &interop.Context{Trigger: trigger.Verification, Container: p} + crypto.Register(ic) + v := ic.SpawnVM() v.GasLimit = payloadGasLimit - v.RegisterInteropGetter(crypto.GetInterop(&interop.Context{Container: p})) v.LoadScript(verification) v.LoadScript(p.Witness.InvocationScript) diff --git a/pkg/core/blockchain.go b/pkg/core/blockchain.go index 96a06ea54..7ac605bed 100644 --- a/pkg/core/blockchain.go +++ b/pkg/core/blockchain.go @@ -566,8 +566,7 @@ func (bc *Blockchain) storeBlock(block *block.Block) error { if block.Index > 0 { systemInterop := bc.newInteropContext(trigger.System, cache, block, nil) - v := SpawnVM(systemInterop) - v.GasLimit = -1 + v := systemInterop.SpawnVM() v.LoadScriptWithFlags(bc.contracts.GetPersistScript(), smartcontract.AllowModifyStates|smartcontract.AllowCall) v.SetPriceGetter(getPrice) if err := v.Run(); err != nil { @@ -599,7 +598,7 @@ func (bc *Blockchain) storeBlock(block *block.Block) error { } systemInterop := bc.newInteropContext(trigger.Application, cache, block, tx) - v := SpawnVM(systemInterop) + v := systemInterop.SpawnVM() v.LoadScriptWithFlags(tx.Script, smartcontract.All) v.SetPriceGetter(getPrice) v.GasLimit = tx.SystemFee @@ -1277,7 +1276,7 @@ func (bc *Blockchain) GetScriptHashesForVerifying(t *transaction.Transaction) ([ // GetTestVM returns a VM and a Store setup for a test run of some sort of code. func (bc *Blockchain) GetTestVM(tx *transaction.Transaction) *vm.VM { systemInterop := bc.newInteropContext(trigger.Application, bc.dao, nil, tx) - vm := SpawnVM(systemInterop) + vm := systemInterop.SpawnVM() vm.SetPriceGetter(getPrice) return vm } @@ -1310,7 +1309,7 @@ func (bc *Blockchain) verifyHashAgainstScript(hash util.Uint160, witness *transa gas = gasPolicy } - vm := SpawnVM(interopCtx) + vm := interopCtx.SpawnVM() vm.SetPriceGetter(getPrice) vm.GasLimit = gas vm.LoadScriptWithFlags(verification, smartcontract.ReadOnly) @@ -1413,6 +1412,7 @@ func (bc *Blockchain) secondsPerBlock() int { func (bc *Blockchain) newInteropContext(trigger trigger.Type, d dao.DAO, block *block.Block, tx *transaction.Transaction) *interop.Context { ic := interop.NewContext(trigger, bc, d, bc.contracts.Contracts, block, tx, bc.log) + ic.Functions = [][]interop.Function{systemInterops, neoInterops} switch { case tx != nil: ic.Container = tx diff --git a/pkg/core/gas_price.go b/pkg/core/gas_price.go index 7741206d5..61657d379 100644 --- a/pkg/core/gas_price.go +++ b/pkg/core/gas_price.go @@ -9,14 +9,6 @@ import ( const StoragePrice = 100000 // getPrice returns a price for executing op with the provided parameter. -// Some SYSCALLs have variable price depending on their arguments. func getPrice(v *vm.VM, op opcode.Opcode, parameter []byte) int64 { - if op == opcode.SYSCALL { - interopID := vm.GetInteropID(parameter) - ifunc := v.GetInteropByID(interopID) - if ifunc != nil && ifunc.Price > 0 { - return ifunc.Price - } - } return opcodePrice(op) } diff --git a/pkg/core/interop/callback/callback.go b/pkg/core/interop/callback/callback.go new file mode 100644 index 000000000..e32afd2bd --- /dev/null +++ b/pkg/core/interop/callback/callback.go @@ -0,0 +1,37 @@ +package callback + +import ( + "errors" + + "github.com/nspcc-dev/neo-go/pkg/core/interop" + "github.com/nspcc-dev/neo-go/pkg/vm" + "github.com/nspcc-dev/neo-go/pkg/vm/emit" + "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" +) + +// Callback is an interface for arbitrary callbacks. +type Callback interface { + // ArgCount returns expected number of arguments. + ArgCount() int + // LoadContext loads context and arguments on stack. + LoadContext(*vm.VM, []stackitem.Item) +} + +// Invoke invokes provided callback. +func Invoke(ic *interop.Context, v *vm.VM) error { + cb := v.Estack().Pop().Interop().Value().(Callback) + args := v.Estack().Pop().Array() + if cb.ArgCount() != len(args) { + return errors.New("invalid argument count") + } + cb.LoadContext(v, args) + switch t := cb.(type) { + case *MethodCallback: + id := emit.InteropNameToID([]byte("System.Contract.Call")) + return ic.SyscallHandler(v, id) + case *SyscallCallback: + return ic.SyscallHandler(v, t.desc.ID) + default: + return nil + } +} diff --git a/pkg/core/interop/callback/method.go b/pkg/core/interop/callback/method.go new file mode 100644 index 000000000..4eade208d --- /dev/null +++ b/pkg/core/interop/callback/method.go @@ -0,0 +1,60 @@ +package callback + +import ( + "errors" + "strings" + + "github.com/nspcc-dev/neo-go/pkg/core/interop" + "github.com/nspcc-dev/neo-go/pkg/core/state" + "github.com/nspcc-dev/neo-go/pkg/smartcontract/manifest" + "github.com/nspcc-dev/neo-go/pkg/util" + "github.com/nspcc-dev/neo-go/pkg/vm" + "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" +) + +// MethodCallback represents callback for contract method. +type MethodCallback struct { + contract *state.Contract + method *manifest.Method +} + +var _ Callback = (*MethodCallback)(nil) + +// ArgCount implements Callback interface. +func (s *MethodCallback) ArgCount() int { + return len(s.method.Parameters) +} + +// LoadContext implements Callback interface. +func (s *MethodCallback) LoadContext(v *vm.VM, args []stackitem.Item) { + v.Estack().PushVal(args) + v.Estack().PushVal(s.method.Name) + v.Estack().PushVal(s.contract.ScriptHash().BytesBE()) +} + +// CreateFromMethod creates callback for a contract method. +func CreateFromMethod(ic *interop.Context, v *vm.VM) error { + rawHash := v.Estack().Pop().Bytes() + h, err := util.Uint160DecodeBytesBE(rawHash) + if err != nil { + return err + } + cs, err := ic.DAO.GetContractState(h) + if err != nil { + return errors.New("contract not found") + } + method := string(v.Estack().Pop().Bytes()) + if strings.HasPrefix(method, "_") { + return errors.New("invalid method name") + } + currCs, err := ic.DAO.GetContractState(v.GetCurrentScriptHash()) + if err == nil && !currCs.Manifest.CanCall(&cs.Manifest, method) { + return errors.New("method call is not allowed") + } + md := cs.Manifest.ABI.GetMethod(method) + v.Estack().PushVal(stackitem.NewInterop(&MethodCallback{ + contract: cs, + method: md, + })) + return nil +} diff --git a/pkg/core/interop/callback/pointer.go b/pkg/core/interop/callback/pointer.go new file mode 100644 index 000000000..87963ae6a --- /dev/null +++ b/pkg/core/interop/callback/pointer.go @@ -0,0 +1,42 @@ +package callback + +import ( + "github.com/nspcc-dev/neo-go/pkg/core/interop" + "github.com/nspcc-dev/neo-go/pkg/vm" + "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" +) + +// PointerCallback represents callback for a pointer. +type PointerCallback struct { + paramCount int + offset int + context *vm.Context +} + +var _ Callback = (*PointerCallback)(nil) + +// ArgCount implements Callback interface. +func (p *PointerCallback) ArgCount() int { + return p.paramCount +} + +// LoadContext implements Callback interface. +func (p *PointerCallback) LoadContext(v *vm.VM, args []stackitem.Item) { + v.Call(p.context, p.offset) + for i := len(args) - 1; i >= 0; i-- { + v.Estack().PushVal(args[i]) + } +} + +// Create creates callback using pointer and parameters count. +func Create(_ *interop.Context, v *vm.VM) error { + ctx := v.Estack().Pop().Item().(*vm.Context) + offset := v.Estack().Pop().Item().(*stackitem.Pointer).Position() + count := v.Estack().Pop().BigInt().Int64() + v.Estack().PushVal(stackitem.NewInterop(&PointerCallback{ + paramCount: int(count), + offset: offset, + context: ctx, + })) + return nil +} diff --git a/pkg/core/interop/callback/syscall.go b/pkg/core/interop/callback/syscall.go new file mode 100644 index 000000000..d52c091d3 --- /dev/null +++ b/pkg/core/interop/callback/syscall.go @@ -0,0 +1,42 @@ +package callback + +import ( + "errors" + + "github.com/nspcc-dev/neo-go/pkg/core/interop" + "github.com/nspcc-dev/neo-go/pkg/vm" + "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" +) + +// SyscallCallback represents callback for a syscall. +type SyscallCallback struct { + desc *interop.Function +} + +var _ Callback = (*SyscallCallback)(nil) + +// ArgCount implements Callback interface. +func (p *SyscallCallback) ArgCount() int { + return p.desc.ParamCount +} + +// LoadContext implements Callback interface. +func (p *SyscallCallback) LoadContext(v *vm.VM, args []stackitem.Item) { + for i := len(args) - 1; i >= 0; i-- { + v.Estack().PushVal(args[i]) + } +} + +// CreateFromSyscall creates callback from syscall. +func CreateFromSyscall(ic *interop.Context, v *vm.VM) error { + id := uint32(v.Estack().Pop().BigInt().Int64()) + f := ic.GetFunction(id) + if f == nil { + return errors.New("syscall not found") + } + if f.DisallowCallback { + return errors.New("syscall is not allowed to be used in a callback") + } + v.Estack().PushVal(stackitem.NewInterop(&SyscallCallback{f})) + return nil +} diff --git a/pkg/core/interop/context.go b/pkg/core/interop/context.go index 559c92291..c35bf7936 100644 --- a/pkg/core/interop/context.go +++ b/pkg/core/interop/context.go @@ -1,6 +1,10 @@ package interop import ( + "errors" + "fmt" + "sort" + "github.com/nspcc-dev/neo-go/pkg/core/block" "github.com/nspcc-dev/neo-go/pkg/core/blockchainer" "github.com/nspcc-dev/neo-go/pkg/core/dao" @@ -32,6 +36,7 @@ type Context struct { Log *zap.Logger Invocations map[util.Uint160]int ScriptGetter vm.ScriptHashGetter + Functions [][]Function } // NewContext returns new interop context. @@ -48,6 +53,8 @@ func NewContext(trigger trigger.Type, bc blockchainer.Blockchainer, d dao.DAO, n Notifications: nes, Log: log, Invocations: make(map[util.Uint160]int), + // Functions is a slice of slices of interops sorted by ID. + Functions: [][]Function{}, } } @@ -55,10 +62,14 @@ func NewContext(trigger trigger.Type, bc blockchainer.Blockchainer, d dao.DAO, n // it's supposed to be inited once for all interopContexts, so it doesn't use // vm.InteropFuncPrice directly. type Function struct { - ID uint32 - Name string - Func func(*Context, *vm.VM) error - Price int64 + ID uint32 + Name string + Func func(*Context, *vm.VM) error + // DisallowCallback is true iff syscall can't be used in a callback. + DisallowCallback bool + // ParamCount is a number of function parameters. + ParamCount int + Price int64 // RequiredFlags is a set of flags which must be set during script invocations. // Default value is NoneFlag i.e. no flags are required. RequiredFlags smartcontract.CallFlag @@ -124,3 +135,46 @@ func (c *ContractMD) AddEvent(name string, ps ...manifest.Parameter) { Parameters: ps, }) } + +// Sort sorts interop functions by id. +func Sort(fs []Function) { + sort.Slice(fs, func(i, j int) bool { return fs[i].ID < fs[j].ID }) +} + +// GetFunction returns metadata for interop with the specified id. +func (ic *Context) GetFunction(id uint32) *Function { + for _, slice := range ic.Functions { + n := sort.Search(len(slice), func(i int) bool { + return slice[i].ID >= id + }) + if n < len(slice) && slice[n].ID == id { + return &slice[n] + } + } + return nil +} + +// SyscallHandler handles syscall with id. +func (ic *Context) SyscallHandler(v *vm.VM, id uint32) error { + f := ic.GetFunction(id) + if f == nil { + return errors.New("syscall not found") + } + cf := v.Context().GetCallFlags() + if !cf.Has(f.RequiredFlags) { + return fmt.Errorf("missing call flags: %05b vs %05b", cf, f.RequiredFlags) + } + if !v.AddGas(f.Price) { + return errors.New("insufficient amount of gas") + } + return f.Func(ic, v) +} + +// SpawnVM spawns new VM with the specified gas limit. +func (ic *Context) SpawnVM() *vm.VM { + v := vm.NewWithTrigger(ic.Trigger) + v.GasLimit = -1 + v.SyscallHandler = ic.SyscallHandler + ic.ScriptGetter = v + return v +} diff --git a/pkg/core/interop/crypto/ecdsa_test.go b/pkg/core/interop/crypto/ecdsa_test.go index 4dd735e87..c2546c81c 100644 --- a/pkg/core/interop/crypto/ecdsa_test.go +++ b/pkg/core/interop/crypto/ecdsa_test.go @@ -59,10 +59,9 @@ func initCHECKMULTISIGVM(t *testing.T, n int, ik, is []int) *vm.VM { buf[0] = byte(opcode.SYSCALL) binary.LittleEndian.PutUint32(buf[1:], ecdsaSecp256r1CheckMultisigID) - v := vm.New() - v.GasLimit = -1 ic := &interop.Context{Trigger: trigger.Verification} - v.RegisterInteropGetter(GetInterop(ic)) + Register(ic) + v := ic.SpawnVM() v.LoadScript(buf) msg := []byte("NEO - An Open Network For Smart Economy") diff --git a/pkg/core/interop/crypto/hash_test.go b/pkg/core/interop/crypto/hash_test.go index 566bcfb6e..9f4bdc22c 100644 --- a/pkg/core/interop/crypto/hash_test.go +++ b/pkg/core/interop/crypto/hash_test.go @@ -7,7 +7,6 @@ import ( "github.com/nspcc-dev/neo-go/pkg/core/interop" "github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger" - "github.com/nspcc-dev/neo-go/pkg/vm" "github.com/nspcc-dev/neo-go/pkg/vm/emit" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -20,9 +19,9 @@ func TestSHA256(t *testing.T) { emit.Bytes(buf.BinWriter, []byte{1, 0}) emit.Syscall(buf.BinWriter, "Neo.Crypto.SHA256") prog := buf.Bytes() - v := vm.New() ic := &interop.Context{Trigger: trigger.Verification} - v.RegisterInteropGetter(GetInterop(ic)) + Register(ic) + v := ic.SpawnVM() v.Load(prog) require.NoError(t, v.Run()) assert.Equal(t, 1, v.Estack().Len()) @@ -36,9 +35,9 @@ func TestRIPEMD160(t *testing.T) { emit.Bytes(buf.BinWriter, []byte{1, 0}) emit.Syscall(buf.BinWriter, "Neo.Crypto.RIPEMD160") prog := buf.Bytes() - v := vm.New() ic := &interop.Context{Trigger: trigger.Verification} - v.RegisterInteropGetter(GetInterop(ic)) + Register(ic) + v := ic.SpawnVM() v.Load(prog) require.NoError(t, v.Run()) assert.Equal(t, 1, v.Estack().Len()) diff --git a/pkg/core/interop/crypto/interop.go b/pkg/core/interop/crypto/interop.go index aff3410d4..a27850577 100644 --- a/pkg/core/interop/crypto/interop.go +++ b/pkg/core/interop/crypto/interop.go @@ -2,47 +2,32 @@ package crypto import ( "github.com/nspcc-dev/neo-go/pkg/core/interop" - "github.com/nspcc-dev/neo-go/pkg/vm" "github.com/nspcc-dev/neo-go/pkg/vm/emit" ) var ( ecdsaSecp256r1VerifyID = emit.InteropNameToID([]byte("Neo.Crypto.VerifyWithECDsaSecp256r1")) + ecdsaSecp256k1VerifyID = emit.InteropNameToID([]byte("Neo.Crypto.VerifyWithECDsaSecp256k1")) ecdsaSecp256r1CheckMultisigID = emit.InteropNameToID([]byte("Neo.Crypto.CheckMultisigWithECDsaSecp256r1")) + ecdsaSecp256k1CheckMultisigID = emit.InteropNameToID([]byte("Neo.Crypto.CheckMultisigWithECDsaSecp256k1")) sha256ID = emit.InteropNameToID([]byte("Neo.Crypto.SHA256")) ripemd160ID = emit.InteropNameToID([]byte("Neo.Crypto.RIPEMD160")) ) -// GetInterop returns interop getter for crypto-related stuff. -func GetInterop(ic *interop.Context) func(uint32) *vm.InteropFuncPrice { - return func(id uint32) *vm.InteropFuncPrice { - switch id { - case ecdsaSecp256r1VerifyID: - return &vm.InteropFuncPrice{ - Func: func(v *vm.VM) error { - return ECDSASecp256r1Verify(ic, v) - }, - } - case ecdsaSecp256r1CheckMultisigID: - return &vm.InteropFuncPrice{ - Func: func(v *vm.VM) error { - return ECDSASecp256r1CheckMultisig(ic, v) - }, - } - case sha256ID: - return &vm.InteropFuncPrice{ - Func: func(v *vm.VM) error { - return Sha256(ic, v) - }, - } - case ripemd160ID: - return &vm.InteropFuncPrice{ - Func: func(v *vm.VM) error { - return RipeMD160(ic, v) - }, - } - default: - return nil - } - } +var cryptoInterops = []interop.Function{ + {ID: ecdsaSecp256r1VerifyID, Func: ECDSASecp256r1Verify}, + {ID: ecdsaSecp256k1VerifyID, Func: ECDSASecp256k1Verify}, + {ID: ecdsaSecp256r1CheckMultisigID, Func: ECDSASecp256r1CheckMultisig}, + {ID: ecdsaSecp256k1CheckMultisigID, Func: ECDSASecp256k1CheckMultisig}, + {ID: sha256ID, Func: Sha256}, + {ID: ripemd160ID, Func: RipeMD160}, +} + +func init() { + interop.Sort(cryptoInterops) +} + +// Register adds crypto interops to ic. +func Register(ic *interop.Context) { + ic.Functions = append(ic.Functions, cryptoInterops) } diff --git a/pkg/core/interop/json/json_test.go b/pkg/core/interop/json/json_test.go index b1d4bb34f..12a86edcc 100644 --- a/pkg/core/interop/json/json_test.go +++ b/pkg/core/interop/json/json_test.go @@ -4,7 +4,7 @@ import ( "encoding/binary" "testing" - "github.com/nspcc-dev/neo-go/pkg/vm" + "github.com/nspcc-dev/neo-go/pkg/core/interop" "github.com/nspcc-dev/neo-go/pkg/vm/emit" "github.com/nspcc-dev/neo-go/pkg/vm/opcode" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" @@ -16,19 +16,13 @@ var ( deserializeID = emit.InteropNameToID([]byte("System.Json.Deserialize")) ) -func getInterop(id uint32) *vm.InteropFuncPrice { - switch id { - case serializeID: - return &vm.InteropFuncPrice{ - Func: func(v *vm.VM) error { return Serialize(nil, v) }, - } - case deserializeID: - return &vm.InteropFuncPrice{ - Func: func(v *vm.VM) error { return Deserialize(nil, v) }, - } - default: - return nil - } +var jsonInterops = []interop.Function{ + {ID: serializeID, Func: Serialize}, + {ID: deserializeID, Func: Deserialize}, +} + +func init() { + interop.Sort(jsonInterops) } func getTestFunc(id uint32, arg interface{}, result interface{}) func(t *testing.T) { @@ -37,8 +31,9 @@ func getTestFunc(id uint32, arg interface{}, result interface{}) func(t *testing binary.LittleEndian.PutUint32(prog[1:], id) return func(t *testing.T) { - v := vm.New() - v.RegisterInteropGetter(getInterop) + ic := &interop.Context{} + ic.Functions = append(ic.Functions, jsonInterops) + v := ic.SpawnVM() v.LoadScript(prog) v.Estack().PushVal(arg) if result == nil { diff --git a/pkg/core/interop_system_test.go b/pkg/core/interop_system_test.go index 33f9414b1..aab9e5159 100644 --- a/pkg/core/interop_system_test.go +++ b/pkg/core/interop_system_test.go @@ -6,6 +6,8 @@ import ( "github.com/nspcc-dev/dbft/crypto" "github.com/nspcc-dev/neo-go/pkg/config/netmode" + "github.com/nspcc-dev/neo-go/pkg/core/interop" + "github.com/nspcc-dev/neo-go/pkg/core/interop/callback" "github.com/nspcc-dev/neo-go/pkg/core/interop/runtime" "github.com/nspcc-dev/neo-go/pkg/core/state" "github.com/nspcc-dev/neo-go/pkg/core/transaction" @@ -324,7 +326,8 @@ func TestBlockchainGetContractState(t *testing.T) { }) } -func getTestContractState() *state.Contract { +// getTestContractState returns 2 contracts second of which is allowed to call the first. +func getTestContractState() (*state.Contract, *state.Contract) { script := []byte{ byte(opcode.ABORT), // abort if no offset was provided byte(opcode.ADD), byte(opcode.RET), @@ -370,45 +373,51 @@ func getTestContractState() *state.Contract { ReturnType: smartcontract.IntegerType, }, } - return &state.Contract{ + cs := &state.Contract{ Script: script, Manifest: *m, ID: 42, } -} -func TestContractCall(t *testing.T) { - v, ic, bc := createVM(t) - defer bc.Close() - - cs := getTestContractState() - require.NoError(t, ic.DAO.PutContractState(cs)) - - currScript := []byte{byte(opcode.NOP)} - initVM := func(v *vm.VM) { - v.Istack().Clear() - v.Estack().Clear() - v.Load(currScript) - v.Estack().PushVal(42) // canary - } - - h := cs.Manifest.ABI.Hash - m := manifest.NewManifest(hash.Hash160(currScript)) + currScript := []byte{byte(opcode.RET)} + m = manifest.NewManifest(hash.Hash160(currScript)) perm := manifest.NewPermission(manifest.PermissionHash, h) perm.Methods.Add("add") perm.Methods.Add("drop") perm.Methods.Add("add3") m.Permissions = append(m.Permissions, *perm) - require.NoError(t, ic.DAO.PutContractState(&state.Contract{ + return cs, &state.Contract{ Script: currScript, Manifest: *m, ID: 123, - })) + } +} + +func loadScript(script []byte, args ...interface{}) *vm.VM { + v := vm.New() + v.LoadScriptWithFlags(script, smartcontract.AllowCall) + for i := range args { + v.Estack().PushVal(args[i]) + } + v.GasLimit = -1 + return v +} + +func TestContractCall(t *testing.T) { + _, ic, bc := createVM(t) + defer bc.Close() + + cs, currCs := getTestContractState() + require.NoError(t, ic.DAO.PutContractState(cs)) + require.NoError(t, ic.DAO.PutContractState(currCs)) + + currScript := currCs.Script + h := cs.Manifest.ABI.Hash addArgs := stackitem.NewArray([]stackitem.Item{stackitem.Make(1), stackitem.Make(2)}) t.Run("Good", func(t *testing.T) { - initVM(v) + v := loadScript(currScript, 42) v.Estack().PushVal(addArgs) v.Estack().PushVal("add") v.Estack().PushVal(h.BytesBE()) @@ -420,7 +429,7 @@ func TestContractCall(t *testing.T) { }) t.Run("CallExInvalidFlag", func(t *testing.T) { - initVM(v) + v := loadScript(currScript, 42) v.Estack().PushVal(byte(0xFF)) v.Estack().PushVal(addArgs) v.Estack().PushVal("add") @@ -430,7 +439,7 @@ func TestContractCall(t *testing.T) { runInvalid := func(args ...interface{}) func(t *testing.T) { return func(t *testing.T) { - initVM(v) + v := loadScript(currScript, 42) for i := range args { v.Estack().PushVal(args[i]) } @@ -450,7 +459,7 @@ func TestContractCall(t *testing.T) { }) t.Run("IsolatedStack", func(t *testing.T) { - initVM(v) + v := loadScript(currScript, 42) v.Estack().PushVal(stackitem.NewArray(nil)) v.Estack().PushVal("drop") v.Estack().PushVal(h.BytesBE()) @@ -461,7 +470,7 @@ func TestContractCall(t *testing.T) { t.Run("CallInitialize", func(t *testing.T) { t.Run("Directly", runInvalid(stackitem.NewArray([]stackitem.Item{}), "_initialize", h.BytesBE())) - initVM(v) + v := loadScript(currScript, 42) v.Estack().PushVal(stackitem.NewArray([]stackitem.Item{stackitem.Make(5)})) v.Estack().PushVal("add3") v.Estack().PushVal(h.BytesBE()) @@ -733,3 +742,147 @@ func TestContractGetCallFlags(t *testing.T) { require.NoError(t, contractGetCallFlags(ic, v)) require.Equal(t, int64(smartcontract.All), v.Estack().Pop().Value().(*big.Int).Int64()) } + +func TestPointerCallback(t *testing.T) { + _, ic, bc := createVM(t) + defer bc.Close() + + script := []byte{ + byte(opcode.NOP), byte(opcode.INC), byte(opcode.RET), + byte(opcode.DIV), byte(opcode.RET), + } + t.Run("Good", func(t *testing.T) { + v := loadScript(script, 2, stackitem.NewPointer(3, script)) + v.Estack().PushVal(v.Context()) + require.NoError(t, callback.Create(ic, v)) + + args := stackitem.NewArray([]stackitem.Item{stackitem.Make(3), stackitem.Make(12)}) + v.Estack().InsertAt(vm.NewElement(args), 1) + require.NoError(t, callback.Invoke(ic, v)) + + require.NoError(t, v.Run()) + require.Equal(t, 1, v.Estack().Len()) + require.Equal(t, big.NewInt(5), v.Estack().Pop().Item().Value()) + }) + t.Run("Invalid", func(t *testing.T) { + t.Run("NotEnoughParameters", func(t *testing.T) { + v := loadScript(script, 2, stackitem.NewPointer(3, script)) + v.Estack().PushVal(v.Context()) + require.NoError(t, callback.Create(ic, v)) + + args := stackitem.NewArray([]stackitem.Item{stackitem.Make(3)}) + v.Estack().InsertAt(vm.NewElement(args), 1) + require.Error(t, callback.Invoke(ic, v)) + }) + }) + +} + +func TestMethodCallback(t *testing.T) { + _, ic, bc := createVM(t) + defer bc.Close() + + cs, currCs := getTestContractState() + require.NoError(t, ic.DAO.PutContractState(cs)) + require.NoError(t, ic.DAO.PutContractState(currCs)) + + ic.Functions = append(ic.Functions, systemInterops) + rawHash := cs.Manifest.ABI.Hash.BytesBE() + + t.Run("Invalid", func(t *testing.T) { + runInvalid := func(args ...interface{}) func(t *testing.T) { + return func(t *testing.T) { + v := loadScript(currCs.Script, 42) + for i := range args { + v.Estack().PushVal(args[i]) + } + require.Error(t, callback.CreateFromMethod(ic, v)) + } + } + t.Run("Hash", runInvalid("add", rawHash[1:])) + t.Run("MissingHash", runInvalid("add", util.Uint160{}.BytesBE())) + t.Run("MissingMethod", runInvalid("sub", rawHash)) + t.Run("DisallowedMethod", runInvalid("ret7", rawHash)) + t.Run("Initialize", runInvalid("_initialize", rawHash)) + t.Run("NotEnoughArguments", func(t *testing.T) { + v := loadScript(currCs.Script, 42, "add", rawHash) + require.NoError(t, callback.CreateFromMethod(ic, v)) + + v.Estack().InsertAt(vm.NewElement(stackitem.NewArray([]stackitem.Item{stackitem.Make(1)})), 1) + require.Error(t, callback.Invoke(ic, v)) + }) + t.Run("CallIsNotAllowed", func(t *testing.T) { + v := vm.New() + v.Load(currCs.Script) + v.Estack().PushVal("add") + v.Estack().PushVal(rawHash) + require.NoError(t, callback.CreateFromMethod(ic, v)) + + args := stackitem.NewArray([]stackitem.Item{stackitem.Make(1), stackitem.Make(5)}) + v.Estack().InsertAt(vm.NewElement(args), 1) + require.Error(t, callback.Invoke(ic, v)) + }) + }) + + t.Run("Good", func(t *testing.T) { + v := loadScript(currCs.Script, 42, "add", rawHash) + require.NoError(t, callback.CreateFromMethod(ic, v)) + + args := stackitem.NewArray([]stackitem.Item{stackitem.Make(1), stackitem.Make(5)}) + v.Estack().InsertAt(vm.NewElement(args), 1) + + require.NoError(t, callback.Invoke(ic, v)) + require.NoError(t, v.Run()) + require.Equal(t, 2, v.Estack().Len()) + require.Equal(t, big.NewInt(6), v.Estack().Pop().Item().Value()) + require.Equal(t, big.NewInt(42), v.Estack().Pop().Item().Value()) + }) +} +func TestSyscallCallback(t *testing.T) { + _, ic, bc := createVM(t) + defer bc.Close() + + ic.Functions = append(ic.Functions, []interop.Function{ + { + ID: 0x42, + Func: func(_ *interop.Context, v *vm.VM) error { + a := v.Estack().Pop().BigInt() + b := v.Estack().Pop().BigInt() + v.Estack().PushVal(new(big.Int).Add(a, b)) + return nil + }, + ParamCount: 2, + }, + { + ID: 0x53, + Func: func(_ *interop.Context, _ *vm.VM) error { return nil }, + DisallowCallback: true, + }, + }) + + t.Run("Good", func(t *testing.T) { + args := stackitem.NewArray([]stackitem.Item{stackitem.Make(12), stackitem.Make(30)}) + v := loadScript([]byte{byte(opcode.RET)}, args, 0x42) + require.NoError(t, callback.CreateFromSyscall(ic, v)) + require.NoError(t, callback.Invoke(ic, v)) + require.Equal(t, 1, v.Estack().Len()) + require.Equal(t, big.NewInt(42), v.Estack().Pop().Item().Value()) + }) + + t.Run("Invalid", func(t *testing.T) { + t.Run("InvalidParameterCount", func(t *testing.T) { + args := stackitem.NewArray([]stackitem.Item{stackitem.Make(12)}) + v := loadScript([]byte{byte(opcode.RET)}, args, 0x42) + require.NoError(t, callback.CreateFromSyscall(ic, v)) + require.Error(t, callback.Invoke(ic, v)) + }) + t.Run("MissingSyscall", func(t *testing.T) { + v := loadScript([]byte{byte(opcode.RET)}, stackitem.NewArray(nil), 0x43) + require.Error(t, callback.CreateFromSyscall(ic, v)) + }) + t.Run("Disallowed", func(t *testing.T) { + v := loadScript([]byte{byte(opcode.RET)}, stackitem.NewArray(nil), 0x53) + require.Error(t, callback.CreateFromSyscall(ic, v)) + }) + }) +} diff --git a/pkg/core/interops.go b/pkg/core/interops.go index 2ee8d780f..6b92e703d 100644 --- a/pkg/core/interops.go +++ b/pkg/core/interops.go @@ -8,9 +8,8 @@ package core */ import ( - "sort" - "github.com/nspcc-dev/neo-go/pkg/core/interop" + "github.com/nspcc-dev/neo-go/pkg/core/interop/callback" "github.com/nspcc-dev/neo-go/pkg/core/interop/crypto" "github.com/nspcc-dev/neo-go/pkg/core/interop/enumerator" "github.com/nspcc-dev/neo-go/pkg/core/interop/iterator" @@ -25,109 +24,101 @@ import ( // SpawnVM returns a VM with script getter and interop functions set // up for current blockchain. func SpawnVM(ic *interop.Context) *vm.VM { - vm := vm.NewWithTrigger(ic.Trigger) - vm.RegisterInteropGetter(getSystemInterop(ic)) - vm.RegisterInteropGetter(getNeoInterop(ic)) - ic.ScriptGetter = vm + vm := ic.SpawnVM() + ic.Functions = [][]interop.Function{systemInterops, neoInterops} return vm } -// getSystemInterop returns matching interop function from the System namespace -// for a given id in the current context. -func getSystemInterop(ic *interop.Context) vm.InteropGetterFunc { - return getInteropFromSlice(ic, systemInterops) -} - -// getNeoInterop returns matching interop function from the Neo and AntShares -// namespaces for a given id in the current context. -func getNeoInterop(ic *interop.Context) vm.InteropGetterFunc { - return getInteropFromSlice(ic, neoInterops) -} - -// getInteropFromSlice returns matching interop function from the given slice of -// interop functions in the current context. -func getInteropFromSlice(ic *interop.Context, slice []interop.Function) func(uint32) *vm.InteropFuncPrice { - return func(id uint32) *vm.InteropFuncPrice { - n := sort.Search(len(slice), func(i int) bool { - return slice[i].ID >= id - }) - if n < len(slice) && slice[n].ID == id { - return &vm.InteropFuncPrice{ - Func: func(v *vm.VM) error { - return slice[n].Func(ic, v) - }, - Price: slice[n].Price, - RequiredFlags: slice[n].RequiredFlags, - } - } - return nil - } -} - // All lists are sorted, keep 'em this way, please. var systemInterops = []interop.Function{ - {Name: "System.Binary.Base64Decode", Func: runtimeDecode, Price: 100000}, - {Name: "System.Binary.Base64Encode", Func: runtimeEncode, Price: 100000}, - {Name: "System.Binary.Deserialize", Func: runtimeDeserialize, Price: 500000}, - {Name: "System.Binary.Serialize", Func: runtimeSerialize, Price: 100000}, - {Name: "System.Blockchain.GetBlock", Func: bcGetBlock, Price: 2500000, RequiredFlags: smartcontract.AllowStates}, - {Name: "System.Blockchain.GetContract", Func: bcGetContract, Price: 1000000, RequiredFlags: smartcontract.AllowStates}, - {Name: "System.Blockchain.GetHeight", Func: bcGetHeight, Price: 400, RequiredFlags: smartcontract.AllowStates}, - {Name: "System.Blockchain.GetTransaction", Func: bcGetTransaction, Price: 1000000, RequiredFlags: smartcontract.AllowStates}, - {Name: "System.Blockchain.GetTransactionFromBlock", Func: bcGetTransactionFromBlock, Price: 1000000, RequiredFlags: smartcontract.AllowStates}, - {Name: "System.Blockchain.GetTransactionHeight", Func: bcGetTransactionHeight, Price: 1000000, RequiredFlags: smartcontract.AllowStates}, - {Name: "System.Contract.Call", Func: contractCall, Price: 1000000, RequiredFlags: smartcontract.AllowCall}, - {Name: "System.Contract.CallEx", Func: contractCallEx, Price: 1000000, RequiredFlags: smartcontract.AllowCall}, - {Name: "System.Contract.Create", Func: contractCreate, Price: 0, RequiredFlags: smartcontract.AllowModifyStates}, - {Name: "System.Contract.CreateStandardAccount", Func: contractCreateStandardAccount, Price: 10000}, - {Name: "System.Contract.Destroy", Func: contractDestroy, Price: 1000000, RequiredFlags: smartcontract.AllowModifyStates}, - {Name: "System.Contract.IsStandard", Func: contractIsStandard, Price: 30000}, - {Name: "System.Contract.GetCallFlags", Func: contractGetCallFlags, Price: 30000}, - {Name: "System.Contract.Update", Func: contractUpdate, Price: 0, RequiredFlags: smartcontract.AllowModifyStates}, - {Name: "System.Enumerator.Concat", Func: enumerator.Concat, Price: 400}, - {Name: "System.Enumerator.Create", Func: enumerator.Create, Price: 400}, - {Name: "System.Enumerator.Next", Func: enumerator.Next, Price: 1000000}, - {Name: "System.Enumerator.Value", Func: enumerator.Value, Price: 400}, - {Name: "System.Iterator.Concat", Func: iterator.Concat, Price: 400}, - {Name: "System.Iterator.Create", Func: iterator.Create, Price: 400}, - {Name: "System.Iterator.Key", Func: iterator.Key, Price: 400}, - {Name: "System.Iterator.Keys", Func: iterator.Keys, Price: 400}, - {Name: "System.Iterator.Values", Func: iterator.Values, Price: 400}, - {Name: "System.Json.Deserialize", Func: json.Deserialize, Price: 500000}, - {Name: "System.Json.Serialize", Func: json.Serialize, Price: 100000}, - {Name: "System.Runtime.CheckWitness", Func: runtime.CheckWitness, Price: 30000, + {Name: "System.Binary.Base64Decode", Func: runtimeDecode, Price: 100000, ParamCount: 1}, + {Name: "System.Binary.Base64Encode", Func: runtimeEncode, Price: 100000, ParamCount: 1}, + {Name: "System.Binary.Deserialize", Func: runtimeDeserialize, Price: 500000, ParamCount: 1}, + {Name: "System.Binary.Serialize", Func: runtimeSerialize, Price: 100000, ParamCount: 1}, + {Name: "System.Blockchain.GetBlock", Func: bcGetBlock, Price: 2500000, + RequiredFlags: smartcontract.AllowStates, ParamCount: 1}, + {Name: "System.Blockchain.GetContract", Func: bcGetContract, Price: 1000000, + RequiredFlags: smartcontract.AllowStates, ParamCount: 1}, + {Name: "System.Blockchain.GetHeight", Func: bcGetHeight, Price: 400, RequiredFlags: smartcontract.AllowStates}, + {Name: "System.Blockchain.GetTransaction", Func: bcGetTransaction, Price: 1000000, + RequiredFlags: smartcontract.AllowStates, ParamCount: 1}, + {Name: "System.Blockchain.GetTransactionFromBlock", Func: bcGetTransactionFromBlock, Price: 1000000, + RequiredFlags: smartcontract.AllowStates, ParamCount: 2}, + {Name: "System.Blockchain.GetTransactionHeight", Func: bcGetTransactionHeight, Price: 1000000, + RequiredFlags: smartcontract.AllowStates, ParamCount: 1}, + {Name: "System.Callback.Create", Func: callback.Create, Price: 400, ParamCount: 3, DisallowCallback: true}, + {Name: "System.Callback.CreateFromMethod", Func: callback.CreateFromMethod, Price: 1000000, ParamCount: 2, DisallowCallback: true}, + {Name: "System.Callback.CreateFromSyscall", Func: callback.CreateFromSyscall, Price: 400, ParamCount: 1, DisallowCallback: true}, + {Name: "System.Callback.Invoke", Func: callback.Invoke, Price: 1000000, ParamCount: 2, DisallowCallback: true}, + {Name: "System.Contract.Call", Func: contractCall, Price: 1000000, + RequiredFlags: smartcontract.AllowCall, ParamCount: 3, DisallowCallback: true}, + {Name: "System.Contract.CallEx", Func: contractCallEx, Price: 1000000, + RequiredFlags: smartcontract.AllowCall, ParamCount: 4, DisallowCallback: true}, + {Name: "System.Contract.Create", Func: contractCreate, Price: 0, + RequiredFlags: smartcontract.AllowModifyStates, ParamCount: 2, DisallowCallback: true}, + {Name: "System.Contract.CreateStandardAccount", Func: contractCreateStandardAccount, Price: 10000, ParamCount: 1, DisallowCallback: true}, + {Name: "System.Contract.Destroy", Func: contractDestroy, Price: 1000000, RequiredFlags: smartcontract.AllowModifyStates, DisallowCallback: true}, + {Name: "System.Contract.IsStandard", Func: contractIsStandard, Price: 30000, ParamCount: 1}, + {Name: "System.Contract.GetCallFlags", Func: contractGetCallFlags, Price: 30000, DisallowCallback: true}, + {Name: "System.Contract.Update", Func: contractUpdate, Price: 0, + RequiredFlags: smartcontract.AllowModifyStates, ParamCount: 2, DisallowCallback: true}, + {Name: "System.Enumerator.Concat", Func: enumerator.Concat, Price: 400, ParamCount: 2, DisallowCallback: true}, + {Name: "System.Enumerator.Create", Func: enumerator.Create, Price: 400, ParamCount: 1, DisallowCallback: true}, + {Name: "System.Enumerator.Next", Func: enumerator.Next, Price: 1000000, ParamCount: 1, DisallowCallback: true}, + {Name: "System.Enumerator.Value", Func: enumerator.Value, Price: 400, ParamCount: 1, DisallowCallback: true}, + {Name: "System.Iterator.Concat", Func: iterator.Concat, Price: 400, ParamCount: 2, DisallowCallback: true}, + {Name: "System.Iterator.Create", Func: iterator.Create, Price: 400, ParamCount: 1, DisallowCallback: true}, + {Name: "System.Iterator.Key", Func: iterator.Key, Price: 400, ParamCount: 1, DisallowCallback: true}, + {Name: "System.Iterator.Keys", Func: iterator.Keys, Price: 400, ParamCount: 1, DisallowCallback: true}, + {Name: "System.Iterator.Values", Func: iterator.Values, Price: 400, ParamCount: 1, DisallowCallback: true}, + {Name: "System.Json.Deserialize", Func: json.Deserialize, Price: 500000, ParamCount: 1}, + {Name: "System.Json.Serialize", Func: json.Serialize, Price: 100000, ParamCount: 1}, + {Name: "System.Runtime.CheckWitness", Func: runtime.CheckWitness, Price: 30000, + RequiredFlags: smartcontract.AllowStates, ParamCount: 1}, {Name: "System.Runtime.GasLeft", Func: runtime.GasLeft, Price: 400}, {Name: "System.Runtime.GetCallingScriptHash", Func: engineGetCallingScriptHash, Price: 400}, {Name: "System.Runtime.GetEntryScriptHash", Func: engineGetEntryScriptHash, Price: 400}, {Name: "System.Runtime.GetExecutingScriptHash", Func: engineGetExecutingScriptHash, Price: 400}, {Name: "System.Runtime.GetInvocationCounter", Func: runtime.GetInvocationCounter, Price: 400}, - {Name: "System.Runtime.GetNotifications", Func: runtime.GetNotifications, Price: 10000}, + {Name: "System.Runtime.GetNotifications", Func: runtime.GetNotifications, Price: 10000, ParamCount: 1}, {Name: "System.Runtime.GetScriptContainer", Func: engineGetScriptContainer, Price: 250}, {Name: "System.Runtime.GetTime", Func: runtimeGetTime, Price: 250, RequiredFlags: smartcontract.AllowStates}, {Name: "System.Runtime.GetTrigger", Func: runtimeGetTrigger, Price: 250}, - {Name: "System.Runtime.Log", Func: runtimeLog, Price: 1000000, RequiredFlags: smartcontract.AllowNotify}, - {Name: "System.Runtime.Notify", Func: runtimeNotify, Price: 1000000, RequiredFlags: smartcontract.AllowNotify}, + {Name: "System.Runtime.Log", Func: runtimeLog, Price: 1000000, RequiredFlags: smartcontract.AllowNotify, + ParamCount: 1, DisallowCallback: true}, + {Name: "System.Runtime.Notify", Func: runtimeNotify, Price: 1000000, RequiredFlags: smartcontract.AllowNotify, + ParamCount: 2, DisallowCallback: true}, {Name: "System.Runtime.Platform", Func: runtimePlatform, Price: 250}, - {Name: "System.Storage.Delete", Func: storageDelete, Price: StoragePrice, RequiredFlags: smartcontract.AllowModifyStates}, - {Name: "System.Storage.Find", Func: storageFind, Price: 1000000, RequiredFlags: smartcontract.AllowStates}, - {Name: "System.Storage.Get", Func: storageGet, Price: 1000000, RequiredFlags: smartcontract.AllowStates}, - {Name: "System.Storage.GetContext", Func: storageGetContext, Price: 400, RequiredFlags: smartcontract.AllowStates}, - {Name: "System.Storage.GetReadOnlyContext", Func: storageGetReadOnlyContext, Price: 400, RequiredFlags: smartcontract.AllowStates}, - {Name: "System.Storage.Put", Func: storagePut, Price: 0, RequiredFlags: smartcontract.AllowModifyStates}, // These don't have static price in C# code. - {Name: "System.Storage.PutEx", Func: storagePutEx, Price: 0, RequiredFlags: smartcontract.AllowModifyStates}, - {Name: "System.Storage.AsReadOnly", Func: storageContextAsReadOnly, Price: 400, RequiredFlags: smartcontract.AllowStates}, + {Name: "System.Storage.Delete", Func: storageDelete, Price: StoragePrice, + RequiredFlags: smartcontract.AllowModifyStates, ParamCount: 2, DisallowCallback: true}, + {Name: "System.Storage.Find", Func: storageFind, Price: 1000000, RequiredFlags: smartcontract.AllowStates, + ParamCount: 2, DisallowCallback: true}, + {Name: "System.Storage.Get", Func: storageGet, Price: 1000000, RequiredFlags: smartcontract.AllowStates, + ParamCount: 2, DisallowCallback: true}, + {Name: "System.Storage.GetContext", Func: storageGetContext, Price: 400, + RequiredFlags: smartcontract.AllowStates, DisallowCallback: true}, + {Name: "System.Storage.GetReadOnlyContext", Func: storageGetReadOnlyContext, Price: 400, + RequiredFlags: smartcontract.AllowStates, DisallowCallback: true}, + {Name: "System.Storage.Put", Func: storagePut, Price: 0, RequiredFlags: smartcontract.AllowModifyStates, + ParamCount: 3, DisallowCallback: true}, // These don't have static price in C# code. + {Name: "System.Storage.PutEx", Func: storagePutEx, Price: 0, RequiredFlags: smartcontract.AllowModifyStates, + ParamCount: 4, DisallowCallback: true}, + {Name: "System.Storage.AsReadOnly", Func: storageContextAsReadOnly, Price: 400, + RequiredFlags: smartcontract.AllowStates, ParamCount: 1, DisallowCallback: true}, } var neoInterops = []interop.Function{ - {Name: "Neo.Crypto.VerifyWithECDsaSecp256r1", Func: crypto.ECDSASecp256r1Verify, Price: crypto.ECDSAVerifyPrice}, - {Name: "Neo.Crypto.VerifyWithECDsaSecp256k1", Func: crypto.ECDSASecp256k1Verify, Price: crypto.ECDSAVerifyPrice}, - {Name: "Neo.Crypto.CheckMultisigWithECDsaSecp256r1", Func: crypto.ECDSASecp256r1CheckMultisig, Price: 0}, - {Name: "Neo.Crypto.CheckMultisigWithECDsaSecp256k1", Func: crypto.ECDSASecp256k1CheckMultisig, Price: 0}, - {Name: "Neo.Crypto.SHA256", Func: crypto.Sha256, Price: 1000000}, - {Name: "Neo.Crypto.RIPEMD160", Func: crypto.RipeMD160, Price: 1000000}, - {Name: "Neo.Native.Call", Func: native.Call, Price: 0}, - {Name: "Neo.Native.Deploy", Func: native.Deploy, Price: 0, RequiredFlags: smartcontract.AllowModifyStates}, + {Name: "Neo.Crypto.VerifyWithECDsaSecp256r1", Func: crypto.ECDSASecp256r1Verify, + Price: crypto.ECDSAVerifyPrice, ParamCount: 3}, + {Name: "Neo.Crypto.VerifyWithECDsaSecp256k1", Func: crypto.ECDSASecp256k1Verify, + Price: crypto.ECDSAVerifyPrice, ParamCount: 3}, + {Name: "Neo.Crypto.CheckMultisigWithECDsaSecp256r1", Func: crypto.ECDSASecp256r1CheckMultisig, Price: 0, ParamCount: 3}, + {Name: "Neo.Crypto.CheckMultisigWithECDsaSecp256k1", Func: crypto.ECDSASecp256k1CheckMultisig, Price: 0, ParamCount: 3}, + {Name: "Neo.Crypto.SHA256", Func: crypto.Sha256, Price: 1000000, ParamCount: 1}, + {Name: "Neo.Crypto.RIPEMD160", Func: crypto.RipeMD160, Price: 1000000, ParamCount: 1}, + {Name: "Neo.Native.Call", Func: native.Call, Price: 0, ParamCount: 1, DisallowCallback: true}, + {Name: "Neo.Native.Deploy", Func: native.Deploy, Price: 0, RequiredFlags: smartcontract.AllowModifyStates, DisallowCallback: true}, } // initIDinInteropsSlice initializes IDs from names in one given @@ -136,9 +127,7 @@ func initIDinInteropsSlice(iops []interop.Function) { for i := range iops { iops[i].ID = emit.InteropNameToID([]byte(iops[i].Name)) } - sort.Slice(iops, func(i, j int) bool { - return iops[i].ID < iops[j].ID - }) + interop.Sort(iops) } // init initializes IDs in the global interop slices. diff --git a/pkg/core/opcode_price.go b/pkg/core/opcode_price.go index d870d6bf6..848826f07 100644 --- a/pkg/core/opcode_price.go +++ b/pkg/core/opcode_price.go @@ -14,65 +14,65 @@ func opcodePrice(opcodes ...opcode.Opcode) int64 { } var prices = map[opcode.Opcode]int64{ - opcode.PUSHINT8: 30, - opcode.PUSHINT32: 30, - opcode.PUSHINT64: 30, - opcode.PUSHINT16: 30, - opcode.PUSHINT128: 120, - opcode.PUSHINT256: 120, - opcode.PUSHA: 120, - opcode.PUSHNULL: 30, - opcode.PUSHDATA1: 180, - opcode.PUSHDATA2: 13000, - opcode.PUSHDATA4: 110000, - opcode.PUSHM1: 30, - opcode.PUSH0: 30, - opcode.PUSH1: 30, - opcode.PUSH2: 30, - opcode.PUSH3: 30, - opcode.PUSH4: 30, - opcode.PUSH5: 30, - opcode.PUSH6: 30, - opcode.PUSH7: 30, - opcode.PUSH8: 30, - opcode.PUSH9: 30, - opcode.PUSH10: 30, - opcode.PUSH11: 30, - opcode.PUSH12: 30, - opcode.PUSH13: 30, - opcode.PUSH14: 30, - opcode.PUSH15: 30, - opcode.PUSH16: 30, - opcode.NOP: 30, - opcode.JMP: 70, - opcode.JMPL: 70, - opcode.JMPIF: 70, - opcode.JMPIFL: 70, - opcode.JMPIFNOT: 70, - opcode.JMPIFNOTL: 70, - opcode.JMPEQ: 70, - opcode.JMPEQL: 70, - opcode.JMPNE: 70, - opcode.JMPNEL: 70, - opcode.JMPGT: 70, - opcode.JMPGTL: 70, - opcode.JMPGE: 70, - opcode.JMPGEL: 70, - opcode.JMPLT: 70, - opcode.JMPLTL: 70, - opcode.JMPLE: 70, - opcode.JMPLEL: 70, - opcode.CALL: 22000, - opcode.CALLL: 22000, - opcode.CALLA: 22000, - opcode.ABORT: 30, - opcode.ASSERT: 30, - opcode.THROW: 22000, - //opcode.TRY: 100, - //opcode.TRY_L: 100, - //opcode.ENDTRY: 100, - //opcode.ENDTRY_L: 100, - //opcode.ENDFINALLY: 100, + opcode.PUSHINT8: 30, + opcode.PUSHINT32: 30, + opcode.PUSHINT64: 30, + opcode.PUSHINT16: 30, + opcode.PUSHINT128: 120, + opcode.PUSHINT256: 120, + opcode.PUSHA: 120, + opcode.PUSHNULL: 30, + opcode.PUSHDATA1: 180, + opcode.PUSHDATA2: 13000, + opcode.PUSHDATA4: 110000, + opcode.PUSHM1: 30, + opcode.PUSH0: 30, + opcode.PUSH1: 30, + opcode.PUSH2: 30, + opcode.PUSH3: 30, + opcode.PUSH4: 30, + opcode.PUSH5: 30, + opcode.PUSH6: 30, + opcode.PUSH7: 30, + opcode.PUSH8: 30, + opcode.PUSH9: 30, + opcode.PUSH10: 30, + opcode.PUSH11: 30, + opcode.PUSH12: 30, + opcode.PUSH13: 30, + opcode.PUSH14: 30, + opcode.PUSH15: 30, + opcode.PUSH16: 30, + opcode.NOP: 30, + opcode.JMP: 70, + opcode.JMPL: 70, + opcode.JMPIF: 70, + opcode.JMPIFL: 70, + opcode.JMPIFNOT: 70, + opcode.JMPIFNOTL: 70, + opcode.JMPEQ: 70, + opcode.JMPEQL: 70, + opcode.JMPNE: 70, + opcode.JMPNEL: 70, + opcode.JMPGT: 70, + opcode.JMPGTL: 70, + opcode.JMPGE: 70, + opcode.JMPGEL: 70, + opcode.JMPLT: 70, + opcode.JMPLTL: 70, + opcode.JMPLE: 70, + opcode.JMPLEL: 70, + opcode.CALL: 22000, + opcode.CALLL: 22000, + opcode.CALLA: 22000, + opcode.ABORT: 30, + opcode.ASSERT: 30, + opcode.THROW: 22000, + opcode.TRY: 100, + opcode.TRYL: 100, + opcode.ENDTRY: 100, + opcode.ENDTRYL: 100, + opcode.ENDFINALLY: 100, opcode.RET: 0, opcode.SYSCALL: 0, opcode.DEPTH: 60, diff --git a/pkg/smartcontract/context/context_test.go b/pkg/smartcontract/context/context_test.go index 5ae5e836d..309135e79 100644 --- a/pkg/smartcontract/context/context_test.go +++ b/pkg/smartcontract/context/context_test.go @@ -110,9 +110,9 @@ func TestParameterContext_AddSignatureMultisig(t *testing.T) { } func newTestVM(w *transaction.Witness, tx *transaction.Transaction) *vm.VM { - v := vm.New() - v.GasLimit = -1 - v.RegisterInteropGetter(crypto.GetInterop(&interop.Context{Container: tx})) + ic := &interop.Context{Container: tx} + crypto.Register(ic) + v := ic.SpawnVM() v.LoadScript(w.VerificationScript) v.LoadScript(w.InvocationScript) return v diff --git a/pkg/vm/interop.go b/pkg/vm/interop.go index 218d93332..bf04a2e8e 100644 --- a/pkg/vm/interop.go +++ b/pkg/vm/interop.go @@ -10,63 +10,59 @@ import ( "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" ) -// InteropFunc allows to hook into the VM. -type InteropFunc func(vm *VM) error - -// InteropFuncPrice represents an interop function with a price. -type InteropFuncPrice struct { - Func InteropFunc +// interopIDFuncPrice adds an ID to the InteropFuncPrice. +type interopIDFuncPrice struct { + ID uint32 + Func func(vm *VM) error Price int64 RequiredFlags smartcontract.CallFlag } -// interopIDFuncPrice adds an ID to the InteropFuncPrice. -type interopIDFuncPrice struct { - ID uint32 - InteropFuncPrice -} - -// InteropGetterFunc is a function that returns an interop function-price -// structure by the given interop ID. -type InteropGetterFunc func(uint32) *InteropFuncPrice - var defaultVMInterops = []interopIDFuncPrice{ - {emit.InteropNameToID([]byte("System.Binary.Deserialize")), - InteropFuncPrice{Func: RuntimeDeserialize, Price: 500000}}, - {emit.InteropNameToID([]byte("System.Binary.Serialize")), - InteropFuncPrice{Func: RuntimeSerialize, Price: 100000}}, - {emit.InteropNameToID([]byte("System.Runtime.Log")), - InteropFuncPrice{Func: runtimeLog, Price: 1000000, RequiredFlags: smartcontract.AllowNotify}}, - {emit.InteropNameToID([]byte("System.Runtime.Notify")), - InteropFuncPrice{Func: runtimeNotify, Price: 1000000, RequiredFlags: smartcontract.AllowNotify}}, - {emit.InteropNameToID([]byte("System.Enumerator.Create")), - InteropFuncPrice{Func: EnumeratorCreate, Price: 400}}, - {emit.InteropNameToID([]byte("System.Enumerator.Next")), - InteropFuncPrice{Func: EnumeratorNext, Price: 1000000}}, - {emit.InteropNameToID([]byte("System.Enumerator.Concat")), - InteropFuncPrice{Func: EnumeratorConcat, Price: 400}}, - {emit.InteropNameToID([]byte("System.Enumerator.Value")), - InteropFuncPrice{Func: EnumeratorValue, Price: 400}}, - {emit.InteropNameToID([]byte("System.Iterator.Create")), - InteropFuncPrice{Func: IteratorCreate, Price: 400}}, - {emit.InteropNameToID([]byte("System.Iterator.Concat")), - InteropFuncPrice{Func: IteratorConcat, Price: 400}}, - {emit.InteropNameToID([]byte("System.Iterator.Key")), - InteropFuncPrice{Func: IteratorKey, Price: 400}}, - {emit.InteropNameToID([]byte("System.Iterator.Keys")), - InteropFuncPrice{Func: IteratorKeys, Price: 400}}, - {emit.InteropNameToID([]byte("System.Iterator.Values")), - InteropFuncPrice{Func: IteratorValues, Price: 400}}, + {ID: emit.InteropNameToID([]byte("System.Binary.Deserialize")), + Func: RuntimeDeserialize, Price: 500000}, + {ID: emit.InteropNameToID([]byte("System.Binary.Serialize")), + Func: RuntimeSerialize, Price: 100000}, + {ID: emit.InteropNameToID([]byte("System.Runtime.Log")), + Func: runtimeLog, Price: 1000000, RequiredFlags: smartcontract.AllowNotify}, + {ID: emit.InteropNameToID([]byte("System.Runtime.Notify")), + Func: runtimeNotify, Price: 1000000, RequiredFlags: smartcontract.AllowNotify}, + {ID: emit.InteropNameToID([]byte("System.Enumerator.Create")), + Func: EnumeratorCreate, Price: 400}, + {ID: emit.InteropNameToID([]byte("System.Enumerator.Next")), + Func: EnumeratorNext, Price: 1000000}, + {ID: emit.InteropNameToID([]byte("System.Enumerator.Concat")), + Func: EnumeratorConcat, Price: 400}, + {ID: emit.InteropNameToID([]byte("System.Enumerator.Value")), + Func: EnumeratorValue, Price: 400}, + {ID: emit.InteropNameToID([]byte("System.Iterator.Create")), + Func: IteratorCreate, Price: 400}, + {ID: emit.InteropNameToID([]byte("System.Iterator.Concat")), + Func: IteratorConcat, Price: 400}, + {ID: emit.InteropNameToID([]byte("System.Iterator.Key")), + Func: IteratorKey, Price: 400}, + {ID: emit.InteropNameToID([]byte("System.Iterator.Keys")), + Func: IteratorKeys, Price: 400}, + {ID: emit.InteropNameToID([]byte("System.Iterator.Values")), + Func: IteratorValues, Price: 400}, } -func getDefaultVMInterop(id uint32) *InteropFuncPrice { +func init() { + sort.Slice(defaultVMInterops, func(i, j int) bool { return defaultVMInterops[i].ID < defaultVMInterops[j].ID }) +} + +func defaultSyscallHandler(v *VM, id uint32) error { n := sort.Search(len(defaultVMInterops), func(i int) bool { return defaultVMInterops[i].ID >= id }) - if n < len(defaultVMInterops) && defaultVMInterops[n].ID == id { - return &defaultVMInterops[n].InteropFuncPrice + if n >= len(defaultVMInterops) || defaultVMInterops[n].ID != id { + return errors.New("syscall not found") } - return nil + d := defaultVMInterops[n] + if !v.Context().callFlag.Has(d.RequiredFlags) { + return fmt.Errorf("missing call flags: %05b vs %05b", v.Context().callFlag, d.RequiredFlags) + } + return d.Func(v) } // runtimeLog handles the syscall "System.Runtime.Log" for printing and logging stuff. diff --git a/pkg/vm/json_test.go b/pkg/vm/json_test.go index 78633cc99..7a86f3a70 100644 --- a/pkg/vm/json_test.go +++ b/pkg/vm/json_test.go @@ -19,6 +19,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/smartcontract" "github.com/nspcc-dev/neo-go/pkg/vm/opcode" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" + "github.com/pkg/errors" "github.com/stretchr/testify/require" ) @@ -110,30 +111,21 @@ func TestUT(t *testing.T) { require.Equal(t, true, testsRan, "neo-vm tests should be available (check submodules)") } -func getTestingInterop(id uint32) *InteropFuncPrice { - f := func(v *VM) error { - v.estack.PushVal(stackitem.NewInterop(new(int))) - return nil - } +func testSyscallHandler(v *VM, id uint32) error { switch id { case 0x77777777: - return &InteropFuncPrice{Func: f} + v.Estack().PushVal(stackitem.NewInterop(new(int))) case 0x66666666: - return &InteropFuncPrice{ - Func: f, - RequiredFlags: smartcontract.ReadOnly, + if !v.Context().callFlag.Has(smartcontract.ReadOnly) { + return errors.New("invalid call flags") } + v.Estack().PushVal(stackitem.NewInterop(new(int))) case 0x55555555: - return &InteropFuncPrice{ - Func: f, - } + v.Estack().PushVal(stackitem.NewInterop(new(int))) case 0xADDEADDE: - return &InteropFuncPrice{ - Func: func(v *VM) error { - v.throw(stackitem.Make("error")) - return nil - }, - } + v.throw(stackitem.Make("error")) + default: + return errors.New("syscall not found") } return nil } @@ -163,7 +155,7 @@ func testFile(t *testing.T, filename string) { prog := []byte(test.Script) vm := load(prog) vm.state = BreakState - vm.RegisterInteropGetter(getTestingInterop) + vm.SyscallHandler = testSyscallHandler for i := range test.Steps { execStep(t, vm, test.Steps[i]) diff --git a/pkg/vm/vm.go b/pkg/vm/vm.go index 0ad2a65f6..dcef4e0e5 100644 --- a/pkg/vm/vm.go +++ b/pkg/vm/vm.go @@ -58,13 +58,13 @@ const ( maxSHLArg = stackitem.MaxBigIntegerSizeBits ) +// SyscallHandler is a type for syscall handler. +type SyscallHandler = func(*VM, uint32) error + // VM represents the virtual machine. type VM struct { state State - // callbacks to get interops. - getInterop []InteropGetterFunc - // callback to get interop price getPrice func(*VM, opcode.Opcode, []byte) int64 @@ -78,6 +78,9 @@ type VM struct { gasConsumed int64 GasLimit int64 + // SyscallHandler handles SYSCALL opcode. + SyscallHandler func(v *VM, id uint32) error + trigger trigger.Type // Public keys cache. @@ -92,17 +95,16 @@ func New() *VM { // NewWithTrigger returns a new VM for executions triggered by t. func NewWithTrigger(t trigger.Type) *VM { vm := &VM{ - getInterop: make([]InteropGetterFunc, 0, 3), // 3 functions is typical for our default usage. - state: HaltState, - istack: NewStack("invocation"), - refs: newRefCounter(), - keys: make(map[string]*keys.PublicKey), - trigger: t, + state: HaltState, + istack: NewStack("invocation"), + refs: newRefCounter(), + keys: make(map[string]*keys.PublicKey), + trigger: t, + + SyscallHandler: defaultSyscallHandler, } vm.estack = vm.newItemStack("evaluation") - - vm.RegisterInteropGetter(getDefaultVMInterop) return vm } @@ -113,13 +115,6 @@ func (v *VM) newItemStack(n string) *Stack { return s } -// RegisterInteropGetter registers the given InteropGetterFunc into VM. There -// can be many interop getters and they're probed in LIFO order wrt their -// registration time. -func (v *VM) RegisterInteropGetter(f InteropGetterFunc) { - v.getInterop = append(v.getInterop, f) -} - // SetPriceGetter registers the given PriceGetterFunc in v. // f accepts vm's Context, current instruction and instruction parameter. func (v *VM) SetPriceGetter(f func(*VM, opcode.Opcode, []byte) int64) { @@ -494,18 +489,6 @@ func GetInteropID(parameter []byte) uint32 { return binary.LittleEndian.Uint32(parameter) } -// GetInteropByID returns interop function together with price. -// Registered callbacks are checked in LIFO order. -func (v *VM) GetInteropByID(id uint32) *InteropFuncPrice { - for i := len(v.getInterop) - 1; i >= 0; i-- { - if ifunc := v.getInterop[i](id); ifunc != nil { - return ifunc - } - } - - return nil -} - // execute performs an instruction cycle in the VM. Acting on the instruction (opcode). func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err error) { // Instead of polluting the whole VM logic with error handling, we will recover @@ -1250,15 +1233,8 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro case opcode.SYSCALL: interopID := GetInteropID(parameter) - ifunc := v.GetInteropByID(interopID) - if !v.Context().callFlag.Has(ifunc.RequiredFlags) { - panic(fmt.Sprintf("missing call flags: %05b vs %05b", v.Context().callFlag, ifunc.RequiredFlags)) - } - - if ifunc == nil { - panic(fmt.Sprintf("interop hook (%q/0x%x) not registered", parameter, interopID)) - } - if err := ifunc.Func(v); err != nil { + err := v.SyscallHandler(v, interopID) + if err != nil { panic(fmt.Sprintf("failed to invoke syscall: %s", err)) } diff --git a/pkg/vm/vm_test.go b/pkg/vm/vm_test.go index f4a0a94ec..2c872cd94 100644 --- a/pkg/vm/vm_test.go +++ b/pkg/vm/vm_test.go @@ -17,26 +17,25 @@ import ( "github.com/nspcc-dev/neo-go/pkg/vm/emit" "github.com/nspcc-dev/neo-go/pkg/vm/opcode" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" + "github.com/pkg/errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func fooInteropGetter(id uint32) *InteropFuncPrice { +func fooInteropHandler(v *VM, id uint32) error { if id == emit.InteropNameToID([]byte("foo")) { - return &InteropFuncPrice{ - Func: func(evm *VM) error { - evm.Estack().PushVal(1) - return nil - }, - Price: 1, + if !v.AddGas(1) { + return errors.New("invalid gas amount") } + v.Estack().PushVal(1) + return nil } - return nil + return errors.New("syscall not found") } func TestInteropHook(t *testing.T) { v := newTestVM() - v.RegisterInteropGetter(fooInteropGetter) + v.SyscallHandler = fooInteropHandler buf := io.NewBufBinWriter() emit.Syscall(buf.BinWriter, "foo") @@ -47,13 +46,6 @@ func TestInteropHook(t *testing.T) { assert.Equal(t, big.NewInt(1), v.estack.Pop().value.Value()) } -func TestRegisterInteropGetter(t *testing.T) { - v := newTestVM() - currRegistered := len(v.getInterop) - v.RegisterInteropGetter(fooInteropGetter) - assert.Equal(t, currRegistered+1, len(v.getInterop)) -} - func TestVM_SetPriceGetter(t *testing.T) { v := newTestVM() prog := []byte{ @@ -819,7 +811,7 @@ func getTestCallFlagsFunc(syscall []byte, flags smartcontract.CallFlag, result i return func(t *testing.T) { script := append([]byte{byte(opcode.SYSCALL)}, syscall...) v := newTestVM() - v.RegisterInteropGetter(getTestingInterop) + v.SyscallHandler = testSyscallHandler v.LoadScriptWithFlags(script, flags) if result == nil { checkVMFailed(t, v)