diff --git a/pkg/compiler/interop_test.go b/pkg/compiler/interop_test.go index a56ea338b..7487b1493 100644 --- a/pkg/compiler/interop_test.go +++ b/pkg/compiler/interop_test.go @@ -14,6 +14,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/core/storage" "github.com/nspcc-dev/neo-go/pkg/crypto/hash" "github.com/nspcc-dev/neo-go/pkg/encoding/address" + "github.com/nspcc-dev/neo-go/pkg/smartcontract" "github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger" "github.com/nspcc-dev/neo-go/pkg/vm" "github.com/nspcc-dev/neo-go/pkg/vm/opcode" @@ -64,7 +65,7 @@ func spawnVM(t *testing.T, ic *interop.Context, src string) *vm.VM { b, err := compiler.Compile(strings.NewReader(src)) require.NoError(t, err) v := core.SpawnVM(ic) - v.Load(b) + v.LoadScriptWithFlags(b, smartcontract.All) return v } diff --git a/pkg/core/blockchain.go b/pkg/core/blockchain.go index 982886bcc..c1fcb4385 100644 --- a/pkg/core/blockchain.go +++ b/pkg/core/blockchain.go @@ -20,6 +20,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/crypto/keys" "github.com/nspcc-dev/neo-go/pkg/encoding/bigint" "github.com/nspcc-dev/neo-go/pkg/io" + "github.com/nspcc-dev/neo-go/pkg/smartcontract" "github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger" "github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/vm" @@ -562,7 +563,7 @@ func (bc *Blockchain) storeBlock(block *block.Block) error { systemInterop := bc.newInteropContext(trigger.Application, cache, block, tx) v := SpawnVM(systemInterop) - v.LoadScript(tx.Script) + v.LoadScriptWithFlags(tx.Script, smartcontract.All) v.SetPriceGetter(getPrice) if bc.config.FreeGasLimit > 0 { v.SetGasLimit(bc.config.FreeGasLimit + tx.SystemFee) @@ -1276,7 +1277,7 @@ func (bc *Blockchain) verifyHashAgainstScript(hash util.Uint160, witness *transa } vm := SpawnVM(interopCtx) - vm.LoadScript(verification) + vm.LoadScriptWithFlags(verification, smartcontract.ReadOnly) vm.LoadScript(witness.InvocationScript) if useKeys { bc.keyCacheLock.RLock() diff --git a/pkg/core/interop/context.go b/pkg/core/interop/context.go index a6ac1f393..10d2c376e 100644 --- a/pkg/core/interop/context.go +++ b/pkg/core/interop/context.go @@ -56,6 +56,11 @@ type Function struct { Name string Func func(*Context, *vm.VM) error Price int + // AllowedTriggers is a set of triggers which are allowed to initiate invocation. + AllowedTriggers trigger.Type + // RequiredFlags is a set of flags which must be set during script invocations. + // Default value is NoneFlag i.e. no flags are required. + RequiredFlags smartcontract.CallFlag } // Method is a signature for a native method. diff --git a/pkg/core/interop_neo.go b/pkg/core/interop_neo.go index 9b57250e6..c41258b7d 100644 --- a/pkg/core/interop_neo.go +++ b/pkg/core/interop_neo.go @@ -10,7 +10,6 @@ import ( "github.com/nspcc-dev/neo-go/pkg/core/state" "github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/smartcontract/manifest" - "github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger" "github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/vm" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" @@ -63,9 +62,6 @@ func storageFind(ic *interop.Context, v *vm.VM) error { // evaluation stack, does a lot of checks and returns Contract if it // succeeds. func createContractStateFromVM(ic *interop.Context, v *vm.VM) (*state.Contract, error) { - if ic.Trigger != trigger.Application { - return nil, errors.New("can't create contract when not triggered by an application") - } script := v.Estack().Pop().Bytes() if len(script) > MaxContractScriptSize { return nil, errors.New("the script is too big") diff --git a/pkg/core/interop_system.go b/pkg/core/interop_system.go index 285060312..da634e51d 100644 --- a/pkg/core/interop_system.go +++ b/pkg/core/interop_system.go @@ -13,7 +13,6 @@ import ( "github.com/nspcc-dev/neo-go/pkg/core/state" "github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/smartcontract" - "github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger" "github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/vm" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" @@ -272,9 +271,6 @@ func checkStorageContext(ic *interop.Context, stc *StorageContext) error { // storageDelete deletes stored key-value pair. func storageDelete(ic *interop.Context, v *vm.VM) error { - if ic.Trigger != trigger.Application && ic.Trigger != trigger.ApplicationR { - return errors.New("can't delete when the trigger is not application") - } stcInterface := v.Estack().Pop().Value() stc, ok := stcInterface.(*StorageContext) if !ok { @@ -337,9 +333,6 @@ func storageGetReadOnlyContext(ic *interop.Context, v *vm.VM) error { } func putWithContextAndFlags(ic *interop.Context, stc *StorageContext, key []byte, value []byte, isConst bool) error { - if ic.Trigger != trigger.Application && ic.Trigger != trigger.ApplicationR { - return errors.New("can't delete when the trigger is not application") - } if len(key) > MaxStorageKeyLen { return errors.New("key is too big") } @@ -423,7 +416,7 @@ func contractCallEx(ic *interop.Context, v *vm.VM) error { return contractCallExInternal(ic, v, h, method, args, flags) } -func contractCallExInternal(ic *interop.Context, v *vm.VM, h []byte, method stackitem.Item, args stackitem.Item, _ smartcontract.CallFlag) error { +func contractCallExInternal(ic *interop.Context, v *vm.VM, h []byte, method stackitem.Item, args stackitem.Item, f smartcontract.CallFlag) error { u, err := util.Uint160DecodeBytesBE(h) if err != nil { return errors.New("invalid contract hash") @@ -442,7 +435,7 @@ func contractCallExInternal(ic *interop.Context, v *vm.VM, h []byte, method stac return errors.New("disallowed method call") } } - v.LoadScript(cs.Script) + v.LoadScriptWithHash(cs.Script, u, v.Context().GetCallFlags()&f) v.Estack().PushVal(args) v.Estack().PushVal(method) return nil @@ -450,9 +443,6 @@ func contractCallExInternal(ic *interop.Context, v *vm.VM, h []byte, method stac // contractDestroy destroys a contract. func contractDestroy(ic *interop.Context, v *vm.VM) error { - if ic.Trigger != trigger.Application { - return errors.New("can't destroy contract when not triggered by application") - } hash := v.GetCurrentScriptHash() cs, err := ic.DAO.GetContractState(hash) if err != nil { diff --git a/pkg/core/interops.go b/pkg/core/interops.go index 819be06ab..8be1c16dc 100644 --- a/pkg/core/interops.go +++ b/pkg/core/interops.go @@ -16,6 +16,8 @@ import ( "github.com/nspcc-dev/neo-go/pkg/core/interop/iterator" "github.com/nspcc-dev/neo-go/pkg/core/interop/runtime" "github.com/nspcc-dev/neo-go/pkg/core/native" + "github.com/nspcc-dev/neo-go/pkg/smartcontract" + "github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger" "github.com/nspcc-dev/neo-go/pkg/vm" "github.com/nspcc-dev/neo-go/pkg/vm/emit" ) @@ -23,7 +25,7 @@ import ( // SpawnVM returns a VM with script getter and interop functions set // up for current blockchain. func SpawnVM(ic *interop.Context) *vm.VM { - vm := vm.New() + vm := vm.NewWithTrigger(ic.Trigger) vm.RegisterInteropGetter(getSystemInterop(ic)) vm.RegisterInteropGetter(getNeoInterop(ic)) if ic.Chain != nil { @@ -52,9 +54,13 @@ func getInteropFromSlice(ic *interop.Context, slice []interop.Function) func(uin return slice[i].ID >= id }) if n < len(slice) && slice[n].ID == id { - return &vm.InteropFuncPrice{Func: func(v *vm.VM) error { - return slice[n].Func(ic, v) - }, Price: slice[n].Price} + return &vm.InteropFuncPrice{ + Func: func(v *vm.VM) error { + return slice[n].Func(ic, v) + }, + Price: slice[n].Price, + RequiredFlags: slice[n].RequiredFlags, + } } return nil } @@ -62,17 +68,28 @@ func getInteropFromSlice(ic *interop.Context, slice []interop.Function) func(uin // All lists are sorted, keep 'em this way, please. var systemInterops = []interop.Function{ - {Name: "System.Blockchain.GetBlock", Func: bcGetBlock, Price: 250}, - {Name: "System.Blockchain.GetContract", Func: bcGetContract, Price: 100}, - {Name: "System.Blockchain.GetHeight", Func: bcGetHeight, Price: 1}, - {Name: "System.Blockchain.GetTransaction", Func: bcGetTransaction, Price: 100}, - {Name: "System.Blockchain.GetTransactionFromBlock", Func: bcGetTransactionFromBlock, Price: 100}, - {Name: "System.Blockchain.GetTransactionHeight", Func: bcGetTransactionHeight, Price: 100}, - {Name: "System.Contract.Call", Func: contractCall, Price: 1}, - {Name: "System.Contract.CallEx", Func: contractCallEx, Price: 1}, - {Name: "System.Contract.Create", Func: contractCreate, Price: 0}, - {Name: "System.Contract.Destroy", Func: contractDestroy, Price: 1}, - {Name: "System.Contract.Update", Func: contractUpdate, Price: 0}, + {Name: "System.Blockchain.GetBlock", Func: bcGetBlock, Price: 250, + AllowedTriggers: trigger.Application, RequiredFlags: smartcontract.AllowStates}, + {Name: "System.Blockchain.GetContract", Func: bcGetContract, Price: 100, + AllowedTriggers: trigger.Application, RequiredFlags: smartcontract.AllowStates}, + {Name: "System.Blockchain.GetHeight", Func: bcGetHeight, Price: 1, + AllowedTriggers: trigger.Application, RequiredFlags: smartcontract.AllowStates}, + {Name: "System.Blockchain.GetTransaction", Func: bcGetTransaction, Price: 100, + AllowedTriggers: trigger.Application, RequiredFlags: smartcontract.AllowStates}, + {Name: "System.Blockchain.GetTransactionFromBlock", Func: bcGetTransactionFromBlock, Price: 100, + AllowedTriggers: trigger.Application, RequiredFlags: smartcontract.AllowStates}, + {Name: "System.Blockchain.GetTransactionHeight", Func: bcGetTransactionHeight, Price: 100, + AllowedTriggers: trigger.Application, RequiredFlags: smartcontract.AllowStates}, + {Name: "System.Contract.Call", Func: contractCall, Price: 1, + AllowedTriggers: trigger.System | trigger.Application, RequiredFlags: smartcontract.AllowCall}, + {Name: "System.Contract.CallEx", Func: contractCallEx, Price: 1, + AllowedTriggers: trigger.System | trigger.Application, RequiredFlags: smartcontract.AllowCall}, + {Name: "System.Contract.Create", Func: contractCreate, Price: 0, + AllowedTriggers: trigger.Application, RequiredFlags: smartcontract.AllowModifyStates}, + {Name: "System.Contract.Destroy", Func: contractDestroy, Price: 1, + AllowedTriggers: trigger.Application, RequiredFlags: smartcontract.AllowModifyStates}, + {Name: "System.Contract.Update", Func: contractUpdate, Price: 0, + AllowedTriggers: trigger.Application, RequiredFlags: smartcontract.AllowModifyStates}, {Name: "System.Enumerator.Concat", Func: enumerator.Concat, Price: 1}, {Name: "System.Enumerator.Create", Func: enumerator.Create, Price: 1}, {Name: "System.Enumerator.Next", Func: enumerator.Next, Price: 1}, @@ -86,29 +103,39 @@ var systemInterops = []interop.Function{ {Name: "System.Iterator.Key", Func: iterator.Key, Price: 1}, {Name: "System.Iterator.Keys", Func: iterator.Keys, Price: 1}, {Name: "System.Iterator.Values", Func: iterator.Values, Price: 1}, - {Name: "System.Runtime.CheckWitness", Func: runtime.CheckWitness, Price: 200}, + {Name: "System.Runtime.CheckWitness", Func: runtime.CheckWitness, Price: 200, RequiredFlags: smartcontract.AllowStates}, {Name: "System.Runtime.Deserialize", Func: runtimeDeserialize, Price: 1}, - {Name: "System.Runtime.GetTime", Func: runtimeGetTime, Price: 1}, + {Name: "System.Runtime.GetTime", Func: runtimeGetTime, Price: 1, + AllowedTriggers: trigger.Application, RequiredFlags: smartcontract.AllowStates}, {Name: "System.Runtime.GetTrigger", Func: runtimeGetTrigger, Price: 1}, - {Name: "System.Runtime.Log", Func: runtimeLog, Price: 1}, - {Name: "System.Runtime.Notify", Func: runtimeNotify, Price: 1}, + {Name: "System.Runtime.Log", Func: runtimeLog, Price: 1, RequiredFlags: smartcontract.AllowNotify}, + {Name: "System.Runtime.Notify", Func: runtimeNotify, Price: 1, RequiredFlags: smartcontract.AllowNotify}, {Name: "System.Runtime.Platform", Func: runtimePlatform, Price: 1}, {Name: "System.Runtime.Serialize", Func: runtimeSerialize, Price: 1}, - {Name: "System.Storage.Delete", Func: storageDelete, Price: 100}, - {Name: "System.Storage.Find", Func: storageFind, Price: 1}, - {Name: "System.Storage.Get", Func: storageGet, Price: 100}, - {Name: "System.Storage.GetContext", Func: storageGetContext, Price: 1}, - {Name: "System.Storage.GetReadOnlyContext", Func: storageGetReadOnlyContext, Price: 1}, - {Name: "System.Storage.Put", Func: storagePut, Price: 0}, // These don't have static price in C# code. - {Name: "System.Storage.PutEx", Func: storagePutEx, Price: 0}, - {Name: "System.Storage.AsReadOnly", Func: storageContextAsReadOnly, Price: 1}, + {Name: "System.Storage.Delete", Func: storageDelete, Price: 100, + AllowedTriggers: trigger.Application, RequiredFlags: smartcontract.AllowModifyStates}, + {Name: "System.Storage.Find", Func: storageFind, Price: 1, + AllowedTriggers: trigger.Application, RequiredFlags: smartcontract.AllowStates}, + {Name: "System.Storage.Get", Func: storageGet, Price: 100, + AllowedTriggers: trigger.Application, RequiredFlags: smartcontract.AllowStates}, + {Name: "System.Storage.GetContext", Func: storageGetContext, Price: 1, + AllowedTriggers: trigger.Application, RequiredFlags: smartcontract.AllowStates}, + {Name: "System.Storage.GetReadOnlyContext", Func: storageGetReadOnlyContext, Price: 1, + AllowedTriggers: trigger.Application, RequiredFlags: smartcontract.AllowStates}, + {Name: "System.Storage.Put", Func: storagePut, Price: 0, + AllowedTriggers: trigger.Application, RequiredFlags: smartcontract.AllowModifyStates}, // These don't have static price in C# code. + {Name: "System.Storage.PutEx", Func: storagePutEx, Price: 0, + AllowedTriggers: trigger.Application, RequiredFlags: smartcontract.AllowModifyStates}, + {Name: "System.Storage.AsReadOnly", Func: storageContextAsReadOnly, Price: 1, + AllowedTriggers: trigger.Application, RequiredFlags: smartcontract.AllowStates}, } var neoInterops = []interop.Function{ {Name: "Neo.Crypto.ECDsaVerify", Func: crypto.ECDSAVerify, Price: 1}, {Name: "Neo.Crypto.ECDsaCheckMultiSig", Func: crypto.ECDSACheckMultisig, Price: 1}, {Name: "Neo.Crypto.SHA256", Func: crypto.Sha256, Price: 1}, - {Name: "Neo.Native.Deploy", Func: native.Deploy, Price: 1}, + {Name: "Neo.Native.Deploy", Func: native.Deploy, Price: 1, + AllowedTriggers: trigger.Application, RequiredFlags: smartcontract.AllowModifyStates}, } // initIDinInteropsSlice initializes IDs from names in one given diff --git a/pkg/core/native/contract.go b/pkg/core/native/contract.go index a16df5af8..34b5cf578 100644 --- a/pkg/core/native/contract.go +++ b/pkg/core/native/contract.go @@ -79,6 +79,9 @@ func getNativeInterop(ic *interop.Context, c interop.Contract) func(v *vm.VM) er if !ok { return fmt.Errorf("method %s not found", name) } + if !v.Context().GetCallFlags().Has(m.RequiredFlags) { + return errors.New("missing call flags") + } result := m.Func(ic, args) v.Estack().PushVal(result) return nil diff --git a/pkg/core/native/native_neo.go b/pkg/core/native/native_neo.go index 2dfd4b2f6..65024ee12 100644 --- a/pkg/core/native/native_neo.go +++ b/pkg/core/native/native_neo.go @@ -77,30 +77,30 @@ func NewNEO() *NEO { desc := newDescriptor("unclaimedGas", smartcontract.IntegerType, manifest.NewParameter("account", smartcontract.Hash160Type), manifest.NewParameter("end", smartcontract.IntegerType)) - md := newMethodAndPrice(n.unclaimedGas, 1, smartcontract.NoneFlag) + md := newMethodAndPrice(n.unclaimedGas, 1, smartcontract.AllowStates) n.AddMethod(md, desc, true) desc = newDescriptor("registerValidator", smartcontract.BoolType, manifest.NewParameter("pubkey", smartcontract.PublicKeyType)) - md = newMethodAndPrice(n.registerValidator, 1, smartcontract.NoneFlag) + md = newMethodAndPrice(n.registerValidator, 1, smartcontract.AllowModifyStates) n.AddMethod(md, desc, false) desc = newDescriptor("vote", smartcontract.BoolType, manifest.NewParameter("account", smartcontract.Hash160Type), manifest.NewParameter("pubkeys", smartcontract.ArrayType)) - md = newMethodAndPrice(n.vote, 1, smartcontract.NoneFlag) + md = newMethodAndPrice(n.vote, 1, smartcontract.AllowModifyStates) n.AddMethod(md, desc, false) desc = newDescriptor("getRegisteredValidators", smartcontract.ArrayType) - md = newMethodAndPrice(n.getRegisteredValidatorsCall, 1, smartcontract.NoneFlag) + md = newMethodAndPrice(n.getRegisteredValidatorsCall, 1, smartcontract.AllowStates) n.AddMethod(md, desc, true) desc = newDescriptor("getValidators", smartcontract.ArrayType) - md = newMethodAndPrice(n.getValidators, 1, smartcontract.NoneFlag) + md = newMethodAndPrice(n.getValidators, 1, smartcontract.AllowStates) n.AddMethod(md, desc, true) desc = newDescriptor("getNextBlockValidators", smartcontract.ArrayType) - md = newMethodAndPrice(n.getNextBlockValidators, 1, smartcontract.NoneFlag) + md = newMethodAndPrice(n.getNextBlockValidators, 1, smartcontract.AllowStates) n.AddMethod(md, desc, true) return n diff --git a/pkg/core/native/native_nep5.go b/pkg/core/native/native_nep5.go index 66d9fbd3d..ada5bbf43 100644 --- a/pkg/core/native/native_nep5.go +++ b/pkg/core/native/native_nep5.go @@ -61,12 +61,12 @@ func newNEP5Native(name string) *nep5TokenNative { n.AddMethod(md, desc, true) desc = newDescriptor("totalSupply", smartcontract.IntegerType) - md = newMethodAndPrice(n.TotalSupply, 1, smartcontract.NoneFlag) + md = newMethodAndPrice(n.TotalSupply, 1, smartcontract.AllowStates) n.AddMethod(md, desc, true) desc = newDescriptor("balanceOf", smartcontract.IntegerType, manifest.NewParameter("account", smartcontract.Hash160Type)) - md = newMethodAndPrice(n.balanceOf, 1, smartcontract.NoneFlag) + md = newMethodAndPrice(n.balanceOf, 1, smartcontract.AllowStates) n.AddMethod(md, desc, true) desc = newDescriptor("transfer", smartcontract.BoolType, @@ -74,7 +74,7 @@ func newNEP5Native(name string) *nep5TokenNative { manifest.NewParameter("to", smartcontract.Hash160Type), manifest.NewParameter("amount", smartcontract.IntegerType), ) - md = newMethodAndPrice(n.Transfer, 1, smartcontract.NoneFlag) + md = newMethodAndPrice(n.Transfer, 1, smartcontract.AllowModifyStates) n.AddMethod(md, desc, false) n.AddEvent("Transfer", desc.Parameters...) diff --git a/pkg/rpc/server/server.go b/pkg/rpc/server/server.go index c75af6fad..009f097af 100644 --- a/pkg/rpc/server/server.go +++ b/pkg/rpc/server/server.go @@ -915,7 +915,7 @@ func (s *Server) invokescript(reqParams request.Params) (interface{}, *response. func (s *Server) runScriptInVM(script []byte) *result.Invoke { vm := s.chain.GetTestVM() vm.SetGasLimit(s.config.MaxGasInvoke) - vm.LoadScript(script) + vm.LoadScriptWithFlags(script, smartcontract.All) _ = vm.Run() result := &result.Invoke{ State: vm.State(), diff --git a/pkg/smartcontract/call_flags.go b/pkg/smartcontract/call_flags.go index 9bc042fd5..853448b87 100644 --- a/pkg/smartcontract/call_flags.go +++ b/pkg/smartcontract/call_flags.go @@ -5,10 +5,16 @@ type CallFlag byte // Default flags. const ( - NoneFlag CallFlag = 0 - AllowModifyStates CallFlag = 1 << iota + AllowStates CallFlag = 1 << iota + AllowModifyStates AllowCall AllowNotify - ReadOnly = AllowCall | AllowNotify - All = AllowModifyStates | AllowCall | AllowNotify + ReadOnly = AllowStates | AllowCall | AllowNotify + All = ReadOnly | AllowModifyStates + NoneFlag CallFlag = 0 ) + +// Has returns true iff all bits set in cf are also set in f. +func (f CallFlag) Has(cf CallFlag) bool { + return f&cf == cf +} diff --git a/pkg/smartcontract/call_flags_test.go b/pkg/smartcontract/call_flags_test.go new file mode 100644 index 000000000..f50f12ce5 --- /dev/null +++ b/pkg/smartcontract/call_flags_test.go @@ -0,0 +1,14 @@ +package smartcontract + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestCallFlag_Has(t *testing.T) { + require.True(t, AllowCall.Has(AllowCall)) + require.True(t, (AllowCall | AllowNotify).Has(AllowCall)) + require.False(t, (AllowCall).Has(AllowCall|AllowNotify)) + require.True(t, All.Has(ReadOnly)) +} diff --git a/pkg/smartcontract/trigger/trigger_type.go b/pkg/smartcontract/trigger/trigger_type.go index 80f499b13..70c1a221a 100644 --- a/pkg/smartcontract/trigger/trigger_type.go +++ b/pkg/smartcontract/trigger/trigger_type.go @@ -1,41 +1,29 @@ package trigger -//go:generate stringer -type=Type +//go:generate stringer -type=Type -output=trigger_type_string.go // Type represents trigger type used in C# reference node: https://github.com/neo-project/neo/blob/c64748ecbac3baeb8045b16af0d518398a6ced24/neo/SmartContract/TriggerType.cs#L3 type Type byte // Viable list of supported trigger type constants. const ( + // System is trigger type that indicates that script is being invoke internally by the system. + System Type = 0x01 + // The verification trigger indicates that the contract is being invoked as a verification function. // The verification function can accept multiple parameters, and should return a boolean value that indicates the validity of the transaction or block. // The entry point of the contract will be invoked if the contract is triggered by Verification: // main(...); // The entry point of the contract must be able to handle this type of invocation. - Verification Type = 0x00 - - // The verificationR trigger indicates that the contract is being invoked as a verification function because it is specified as a target of an output of the transaction. - // The verification function accepts no parameter, and should return a boolean value that indicates the validity of the transaction. - // The entry point of the contract will be invoked if the contract is triggered by VerificationR: - // main("receiving", new object[0]); - // The receiving function should have the following signature: - // public bool receiving() - // The receiving function will be invoked automatically when a contract is receiving assets from a transfer. - VerificationR Type = 0x01 + Verification Type = 0x20 // The application trigger indicates that the contract is being invoked as an application function. // The application function can accept multiple parameters, change the states of the blockchain, and return any type of value. // The contract can have any form of entry point, but we recommend that all contracts should have the following entry point: // public byte[] main(string operation, params object[] args) // The functions can be invoked by creating an InvocationTransaction. - Application Type = 0x10 + Application Type = 0x40 - // The ApplicationR trigger indicates that the default function received of the contract is being invoked because it is specified as a target of an output of the transaction. - // The received function accepts no parameter, changes the states of the blockchain, and returns any type of value. - // The entry point of the contract will be invoked if the contract is triggered by ApplicationR: - // main("received", new object[0]); - // The received function should have the following signature: - // public byte[] received() - // The received function will be invoked automatically when a contract is receiving assets from a transfer. - ApplicationR Type = 0x11 + // All represents any trigger type. + All = System | Verification | Application ) diff --git a/pkg/smartcontract/trigger/trigger_type_string.go b/pkg/smartcontract/trigger/trigger_type_string.go index 298846f94..b91698c75 100644 --- a/pkg/smartcontract/trigger/trigger_type_string.go +++ b/pkg/smartcontract/trigger/trigger_type_string.go @@ -1,4 +1,4 @@ -// Code generated by "stringer -type=Type"; DO NOT EDIT. +// Code generated by "stringer -type=Type -output=trigger_type_string.go"; DO NOT EDIT. package trigger @@ -8,29 +8,25 @@ func _() { // An "invalid array index" compiler error signifies that the constant values have changed. // Re-run the stringer command to generate them again. var x [1]struct{} - _ = x[Verification-0] - _ = x[VerificationR-1] - _ = x[Application-16] - _ = x[ApplicationR-17] + _ = x[System-1] + _ = x[Verification-32] + _ = x[Application-64] } const ( - _Type_name_0 = "VerificationVerificationR" - _Type_name_1 = "ApplicationApplicationR" -) - -var ( - _Type_index_0 = [...]uint8{0, 12, 25} - _Type_index_1 = [...]uint8{0, 11, 23} + _Type_name_0 = "System" + _Type_name_1 = "Verification" + _Type_name_2 = "Application" ) func (i Type) String() string { switch { - case i <= 1: - return _Type_name_0[_Type_index_0[i]:_Type_index_0[i+1]] - case 16 <= i && i <= 17: - i -= 16 - return _Type_name_1[_Type_index_1[i]:_Type_index_1[i+1]] + case i == 1: + return _Type_name_0 + case i == 32: + return _Type_name_1 + case i == 64: + return _Type_name_2 default: return "Type(" + strconv.FormatInt(int64(i), 10) + ")" } diff --git a/pkg/smartcontract/trigger/trigger_type_test.go b/pkg/smartcontract/trigger/trigger_type_test.go index 211b7436d..bf3bd9976 100644 --- a/pkg/smartcontract/trigger/trigger_type_test.go +++ b/pkg/smartcontract/trigger/trigger_type_test.go @@ -8,10 +8,9 @@ import ( func TestStringer(t *testing.T) { tests := map[Type]string{ - Application: "Application", - ApplicationR: "ApplicationR", - Verification: "Verification", - VerificationR: "VerificationR", + System: "System", + Application: "Application", + Verification: "Verification", } for o, s := range tests { assert.Equal(t, s, o.String()) @@ -20,10 +19,9 @@ func TestStringer(t *testing.T) { func TestEncodeBynary(t *testing.T) { tests := map[Type]byte{ - Verification: 0x00, - VerificationR: 0x01, - Application: 0x10, - ApplicationR: 0x11, + System: 0x01, + Verification: 0x20, + Application: 0x40, } for o, b := range tests { assert.Equal(t, b, byte(o)) @@ -32,10 +30,9 @@ func TestEncodeBynary(t *testing.T) { func TestDecodeBynary(t *testing.T) { tests := map[Type]byte{ - Verification: 0x00, - VerificationR: 0x01, - Application: 0x10, - ApplicationR: 0x11, + System: 0x01, + Verification: 0x20, + Application: 0x40, } for o, b := range tests { assert.Equal(t, o, Type(b)) diff --git a/pkg/vm/context.go b/pkg/vm/context.go index 83e385d19..49ba04353 100644 --- a/pkg/vm/context.go +++ b/pkg/vm/context.go @@ -6,6 +6,7 @@ import ( "math/big" "github.com/nspcc-dev/neo-go/pkg/crypto/hash" + "github.com/nspcc-dev/neo-go/pkg/smartcontract" "github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/vm/opcode" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" @@ -39,6 +40,9 @@ type Context struct { // Script hash of the prog. scriptHash util.Uint160 + + // Call flags this context was created with. + callFlag smartcontract.CallFlag } var errNoInstParam = errors.New("failed to read instruction parameter") @@ -154,6 +158,11 @@ func (c *Context) Copy() *Context { return ctx } +// GetCallFlags returns calling flags context was created with. +func (c *Context) GetCallFlags() smartcontract.CallFlag { + return c.callFlag +} + // Program returns the loaded program. func (c *Context) Program() []byte { return c.prog diff --git a/pkg/vm/interop.go b/pkg/vm/interop.go index 4024a09f4..2d8fb848c 100644 --- a/pkg/vm/interop.go +++ b/pkg/vm/interop.go @@ -5,6 +5,8 @@ import ( "fmt" "sort" + "github.com/nspcc-dev/neo-go/pkg/smartcontract" + "github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger" "github.com/nspcc-dev/neo-go/pkg/vm/emit" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" ) @@ -16,6 +18,10 @@ type InteropFunc func(vm *VM) error type InteropFuncPrice struct { Func InteropFunc Price int + // AllowedTriggers is a mask representing triggers which should be allowed by an interop. + // 0 is interpreted as All. + AllowedTriggers trigger.Type + RequiredFlags smartcontract.CallFlag } // interopIDFuncPrice adds an ID to the InteropFuncPrice. @@ -30,31 +36,31 @@ type InteropGetterFunc func(uint32) *InteropFuncPrice var defaultVMInterops = []interopIDFuncPrice{ {emit.InteropNameToID([]byte("System.Runtime.Log")), - InteropFuncPrice{runtimeLog, 1}}, + InteropFuncPrice{Func: runtimeLog, Price: 1}}, {emit.InteropNameToID([]byte("System.Runtime.Notify")), - InteropFuncPrice{runtimeNotify, 1}}, + InteropFuncPrice{Func: runtimeNotify, Price: 1}}, {emit.InteropNameToID([]byte("System.Runtime.Serialize")), - InteropFuncPrice{RuntimeSerialize, 1}}, + InteropFuncPrice{Func: RuntimeSerialize, Price: 1}}, {emit.InteropNameToID([]byte("System.Runtime.Deserialize")), - InteropFuncPrice{RuntimeDeserialize, 1}}, + InteropFuncPrice{Func: RuntimeDeserialize, Price: 1}}, {emit.InteropNameToID([]byte("System.Enumerator.Create")), - InteropFuncPrice{EnumeratorCreate, 1}}, + InteropFuncPrice{Func: EnumeratorCreate, Price: 1}}, {emit.InteropNameToID([]byte("System.Enumerator.Next")), - InteropFuncPrice{EnumeratorNext, 1}}, + InteropFuncPrice{Func: EnumeratorNext, Price: 1}}, {emit.InteropNameToID([]byte("System.Enumerator.Concat")), - InteropFuncPrice{EnumeratorConcat, 1}}, + InteropFuncPrice{Func: EnumeratorConcat, Price: 1}}, {emit.InteropNameToID([]byte("System.Enumerator.Value")), - InteropFuncPrice{EnumeratorValue, 1}}, + InteropFuncPrice{Func: EnumeratorValue, Price: 1}}, {emit.InteropNameToID([]byte("System.Iterator.Create")), - InteropFuncPrice{IteratorCreate, 1}}, + InteropFuncPrice{Func: IteratorCreate, Price: 1}}, {emit.InteropNameToID([]byte("System.Iterator.Concat")), - InteropFuncPrice{IteratorConcat, 1}}, + InteropFuncPrice{Func: IteratorConcat, Price: 1}}, {emit.InteropNameToID([]byte("System.Iterator.Key")), - InteropFuncPrice{IteratorKey, 1}}, + InteropFuncPrice{Func: IteratorKey, Price: 1}}, {emit.InteropNameToID([]byte("System.Iterator.Keys")), - InteropFuncPrice{IteratorKeys, 1}}, + InteropFuncPrice{Func: IteratorKeys, Price: 1}}, {emit.InteropNameToID([]byte("System.Iterator.Values")), - InteropFuncPrice{IteratorValues, 1}}, + InteropFuncPrice{Func: IteratorValues, Price: 1}}, } func getDefaultVMInterop(id uint32) *InteropFuncPrice { diff --git a/pkg/vm/json_test.go b/pkg/vm/json_test.go index ee2823d49..614aec000 100644 --- a/pkg/vm/json_test.go +++ b/pkg/vm/json_test.go @@ -17,6 +17,8 @@ import ( "testing" "github.com/nspcc-dev/neo-go/pkg/encoding/bigint" + "github.com/nspcc-dev/neo-go/pkg/smartcontract" + "github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger" "github.com/nspcc-dev/neo-go/pkg/vm/opcode" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" "github.com/stretchr/testify/require" @@ -111,11 +113,23 @@ func TestUT(t *testing.T) { } func getTestingInterop(id uint32) *InteropFuncPrice { - if id == binary.LittleEndian.Uint32([]byte{0x77, 0x77, 0x77, 0x77}) { - return &InteropFuncPrice{InteropFunc(func(v *VM) error { - v.estack.PushVal(stackitem.NewInterop(new(int))) - return nil - }), 0} + f := func(v *VM) error { + v.estack.PushVal(stackitem.NewInterop(new(int))) + return nil + } + switch id { + case binary.LittleEndian.Uint32([]byte{0x77, 0x77, 0x77, 0x77}): + return &InteropFuncPrice{Func: f} + case binary.LittleEndian.Uint32([]byte{0x66, 0x66, 0x66, 0x66}): + return &InteropFuncPrice{ + Func: f, + RequiredFlags: smartcontract.ReadOnly, + } + case binary.LittleEndian.Uint32([]byte{0x55, 0x55, 0x55, 0x55}): + return &InteropFuncPrice{ + Func: f, + AllowedTriggers: trigger.Application, + } } return nil } diff --git a/pkg/vm/vm.go b/pkg/vm/vm.go index 7503e5869..9cb456583 100644 --- a/pkg/vm/vm.go +++ b/pkg/vm/vm.go @@ -13,6 +13,8 @@ import ( "github.com/nspcc-dev/neo-go/pkg/crypto/keys" "github.com/nspcc-dev/neo-go/pkg/encoding/bigint" + "github.com/nspcc-dev/neo-go/pkg/smartcontract" + "github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger" "github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/vm/opcode" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" @@ -84,18 +86,26 @@ type VM struct { gasConsumed util.Fixed8 gasLimit util.Fixed8 + trigger trigger.Type + // Public keys cache. keys map[string]*keys.PublicKey } // New returns a new VM object ready to load .avm bytecode scripts. func New() *VM { + return NewWithTrigger(trigger.System) +} + +// NewWithTrigger returns a new VM for executions triggered by t. +func NewWithTrigger(t trigger.Type) *VM { vm := &VM{ getInterop: make([]InteropGetterFunc, 0, 3), // 3 functions is typical for our default usage. state: haltState, istack: NewStack("invocation"), refs: newRefCounter(), keys: make(map[string]*keys.PublicKey), + trigger: t, } vm.estack = vm.newItemStack("evaluation") @@ -262,12 +272,27 @@ func (v *VM) Load(prog []byte) { // will immediately push a new context created from this script to // the invocation stack and starts executing it. func (v *VM) LoadScript(b []byte) { + v.LoadScriptWithFlags(b, smartcontract.NoneFlag) +} + +// LoadScriptWithFlags loads script and sets call flag to f. +func (v *VM) LoadScriptWithFlags(b []byte, f smartcontract.CallFlag) { ctx := NewContext(b) ctx.estack = v.estack ctx.astack = v.astack + ctx.callFlag = f v.istack.PushVal(ctx) } +// LoadScriptWithHash if similar to the LoadScriptWithFlags method, but it also loads +// given script hash directly into the Context to avoid its recalculations. It's +// up to user of this function to make sure the script and hash match each other. +func (v *VM) LoadScriptWithHash(b []byte, hash util.Uint160, f smartcontract.CallFlag) { + v.LoadScriptWithFlags(b, f) + ctx := v.Context() + ctx.scriptHash = hash +} + // Context returns the current executed context. Nil if there is no context, // which implies no program is loaded. func (v *VM) Context() *Context { @@ -1244,6 +1269,12 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro case opcode.SYSCALL: interopID := GetInteropID(parameter) ifunc := v.GetInteropByID(interopID) + if ifunc.AllowedTriggers != 0 && ifunc.AllowedTriggers&v.trigger == 0 { + panic(fmt.Sprintf("trigger not allowed: %s", v.trigger)) + } + if !v.Context().callFlag.Has(ifunc.RequiredFlags) { + panic(fmt.Sprintf("missing call flags: %05b vs %05b", v.Context().callFlag, ifunc.RequiredFlags)) + } if ifunc == nil { panic(fmt.Sprintf("interop hook (%q/0x%x) not registered", parameter, interopID)) diff --git a/pkg/vm/vm_test.go b/pkg/vm/vm_test.go index ae9a06fd0..17ba7c72d 100644 --- a/pkg/vm/vm_test.go +++ b/pkg/vm/vm_test.go @@ -12,6 +12,8 @@ import ( "github.com/nspcc-dev/neo-go/pkg/encoding/bigint" "github.com/nspcc-dev/neo-go/pkg/internal/random" "github.com/nspcc-dev/neo-go/pkg/io" + "github.com/nspcc-dev/neo-go/pkg/smartcontract" + "github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger" "github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/vm/emit" "github.com/nspcc-dev/neo-go/pkg/vm/opcode" @@ -22,10 +24,13 @@ import ( func fooInteropGetter(id uint32) *InteropFuncPrice { if id == emit.InteropNameToID([]byte("foo")) { - return &InteropFuncPrice{func(evm *VM) error { - evm.Estack().PushVal(1) - return nil - }, 1} + return &InteropFuncPrice{ + Func: func(evm *VM) error { + evm.Estack().PushVal(1) + return nil + }, + Price: 1, + } } return nil } @@ -812,6 +817,54 @@ func TestSerializeInterop(t *testing.T) { require.True(t, vm.HasFailed()) } +func getTestCallFlagsFunc(syscall []byte, flags smartcontract.CallFlag, result interface{}) func(t *testing.T) { + return func(t *testing.T) { + script := append([]byte{byte(opcode.SYSCALL)}, syscall...) + v := New() + v.RegisterInteropGetter(getTestingInterop) + v.LoadScriptWithFlags(script, flags) + if result == nil { + checkVMFailed(t, v) + return + } + runVM(t, v) + require.Equal(t, result, v.PopResult()) + } +} + +func TestCallFlags(t *testing.T) { + noFlags := []byte{0x77, 0x77, 0x77, 0x77} + readOnly := []byte{0x66, 0x66, 0x66, 0x66} + t.Run("NoFlagsNoRequired", getTestCallFlagsFunc(noFlags, smartcontract.NoneFlag, new(int))) + t.Run("ProvideFlagsNoRequired", getTestCallFlagsFunc(noFlags, smartcontract.AllowCall, new(int))) + t.Run("NoFlagsSomeRequired", getTestCallFlagsFunc(readOnly, smartcontract.NoneFlag, nil)) + t.Run("OnlyOneProvided", getTestCallFlagsFunc(readOnly, smartcontract.AllowCall, nil)) + t.Run("AllFlagsProvided", getTestCallFlagsFunc(readOnly, smartcontract.ReadOnly, new(int))) +} + +func getTestTriggerFunc(syscall []byte, tr trigger.Type, result interface{}) func(t *testing.T) { + return func(t *testing.T) { + script := append([]byte{byte(opcode.SYSCALL)}, syscall...) + v := NewWithTrigger(tr) + v.RegisterInteropGetter(getTestingInterop) + v.LoadScript(script) + if result == nil { + checkVMFailed(t, v) + return + } + runVM(t, v) + require.Equal(t, result, v.PopResult()) + } +} + +func TestAllowedTriggers(t *testing.T) { + noFlags := []byte{0x77, 0x77, 0x77, 0x77} + appOnly := []byte{0x55, 0x55, 0x55, 0x55} + t.Run("Application/NeedNothing", getTestTriggerFunc(noFlags, trigger.Application, new(int))) + t.Run("Application/NeedApplication", getTestTriggerFunc(appOnly, trigger.Application, new(int))) + t.Run("System/NeedApplication", getTestTriggerFunc(appOnly, trigger.System, nil)) +} + func callNTimes(n uint16) []byte { return makeProgram( opcode.PUSHINT16, opcode.Opcode(n), opcode.Opcode(n>>8), // little-endian