Merge pull request #1044 from nspcc-dev/feature/interop_flags

Provide required call flags and allowed triggers in interop descriptions.
This commit is contained in:
Roman Khimov 2020-06-11 16:24:58 +03:00 committed by GitHub
commit 3ef35e0fc7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
20 changed files with 269 additions and 132 deletions

View file

@ -14,6 +14,7 @@ import (
"github.com/nspcc-dev/neo-go/pkg/core/storage" "github.com/nspcc-dev/neo-go/pkg/core/storage"
"github.com/nspcc-dev/neo-go/pkg/crypto/hash" "github.com/nspcc-dev/neo-go/pkg/crypto/hash"
"github.com/nspcc-dev/neo-go/pkg/encoding/address" "github.com/nspcc-dev/neo-go/pkg/encoding/address"
"github.com/nspcc-dev/neo-go/pkg/smartcontract"
"github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger" "github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger"
"github.com/nspcc-dev/neo-go/pkg/vm" "github.com/nspcc-dev/neo-go/pkg/vm"
"github.com/nspcc-dev/neo-go/pkg/vm/opcode" "github.com/nspcc-dev/neo-go/pkg/vm/opcode"
@ -64,7 +65,7 @@ func spawnVM(t *testing.T, ic *interop.Context, src string) *vm.VM {
b, err := compiler.Compile(strings.NewReader(src)) b, err := compiler.Compile(strings.NewReader(src))
require.NoError(t, err) require.NoError(t, err)
v := core.SpawnVM(ic) v := core.SpawnVM(ic)
v.Load(b) v.LoadScriptWithFlags(b, smartcontract.All)
return v return v
} }

View file

@ -20,6 +20,7 @@ import (
"github.com/nspcc-dev/neo-go/pkg/crypto/keys" "github.com/nspcc-dev/neo-go/pkg/crypto/keys"
"github.com/nspcc-dev/neo-go/pkg/encoding/bigint" "github.com/nspcc-dev/neo-go/pkg/encoding/bigint"
"github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/io"
"github.com/nspcc-dev/neo-go/pkg/smartcontract"
"github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger" "github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger"
"github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/util"
"github.com/nspcc-dev/neo-go/pkg/vm" "github.com/nspcc-dev/neo-go/pkg/vm"
@ -562,7 +563,7 @@ func (bc *Blockchain) storeBlock(block *block.Block) error {
systemInterop := bc.newInteropContext(trigger.Application, cache, block, tx) systemInterop := bc.newInteropContext(trigger.Application, cache, block, tx)
v := SpawnVM(systemInterop) v := SpawnVM(systemInterop)
v.LoadScript(tx.Script) v.LoadScriptWithFlags(tx.Script, smartcontract.All)
v.SetPriceGetter(getPrice) v.SetPriceGetter(getPrice)
if bc.config.FreeGasLimit > 0 { if bc.config.FreeGasLimit > 0 {
v.SetGasLimit(bc.config.FreeGasLimit + tx.SystemFee) v.SetGasLimit(bc.config.FreeGasLimit + tx.SystemFee)
@ -1276,7 +1277,7 @@ func (bc *Blockchain) verifyHashAgainstScript(hash util.Uint160, witness *transa
} }
vm := SpawnVM(interopCtx) vm := SpawnVM(interopCtx)
vm.LoadScript(verification) vm.LoadScriptWithFlags(verification, smartcontract.ReadOnly)
vm.LoadScript(witness.InvocationScript) vm.LoadScript(witness.InvocationScript)
if useKeys { if useKeys {
bc.keyCacheLock.RLock() bc.keyCacheLock.RLock()

View file

@ -56,6 +56,11 @@ type Function struct {
Name string Name string
Func func(*Context, *vm.VM) error Func func(*Context, *vm.VM) error
Price int Price int
// AllowedTriggers is a set of triggers which are allowed to initiate invocation.
AllowedTriggers trigger.Type
// RequiredFlags is a set of flags which must be set during script invocations.
// Default value is NoneFlag i.e. no flags are required.
RequiredFlags smartcontract.CallFlag
} }
// Method is a signature for a native method. // Method is a signature for a native method.

View file

@ -10,7 +10,6 @@ import (
"github.com/nspcc-dev/neo-go/pkg/core/state" "github.com/nspcc-dev/neo-go/pkg/core/state"
"github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/io"
"github.com/nspcc-dev/neo-go/pkg/smartcontract/manifest" "github.com/nspcc-dev/neo-go/pkg/smartcontract/manifest"
"github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger"
"github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/util"
"github.com/nspcc-dev/neo-go/pkg/vm" "github.com/nspcc-dev/neo-go/pkg/vm"
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
@ -63,9 +62,6 @@ func storageFind(ic *interop.Context, v *vm.VM) error {
// evaluation stack, does a lot of checks and returns Contract if it // evaluation stack, does a lot of checks and returns Contract if it
// succeeds. // succeeds.
func createContractStateFromVM(ic *interop.Context, v *vm.VM) (*state.Contract, error) { func createContractStateFromVM(ic *interop.Context, v *vm.VM) (*state.Contract, error) {
if ic.Trigger != trigger.Application {
return nil, errors.New("can't create contract when not triggered by an application")
}
script := v.Estack().Pop().Bytes() script := v.Estack().Pop().Bytes()
if len(script) > MaxContractScriptSize { if len(script) > MaxContractScriptSize {
return nil, errors.New("the script is too big") return nil, errors.New("the script is too big")

View file

@ -13,7 +13,6 @@ import (
"github.com/nspcc-dev/neo-go/pkg/core/state" "github.com/nspcc-dev/neo-go/pkg/core/state"
"github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/core/transaction"
"github.com/nspcc-dev/neo-go/pkg/smartcontract" "github.com/nspcc-dev/neo-go/pkg/smartcontract"
"github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger"
"github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/util"
"github.com/nspcc-dev/neo-go/pkg/vm" "github.com/nspcc-dev/neo-go/pkg/vm"
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
@ -272,9 +271,6 @@ func checkStorageContext(ic *interop.Context, stc *StorageContext) error {
// storageDelete deletes stored key-value pair. // storageDelete deletes stored key-value pair.
func storageDelete(ic *interop.Context, v *vm.VM) error { func storageDelete(ic *interop.Context, v *vm.VM) error {
if ic.Trigger != trigger.Application && ic.Trigger != trigger.ApplicationR {
return errors.New("can't delete when the trigger is not application")
}
stcInterface := v.Estack().Pop().Value() stcInterface := v.Estack().Pop().Value()
stc, ok := stcInterface.(*StorageContext) stc, ok := stcInterface.(*StorageContext)
if !ok { if !ok {
@ -337,9 +333,6 @@ func storageGetReadOnlyContext(ic *interop.Context, v *vm.VM) error {
} }
func putWithContextAndFlags(ic *interop.Context, stc *StorageContext, key []byte, value []byte, isConst bool) error { func putWithContextAndFlags(ic *interop.Context, stc *StorageContext, key []byte, value []byte, isConst bool) error {
if ic.Trigger != trigger.Application && ic.Trigger != trigger.ApplicationR {
return errors.New("can't delete when the trigger is not application")
}
if len(key) > MaxStorageKeyLen { if len(key) > MaxStorageKeyLen {
return errors.New("key is too big") return errors.New("key is too big")
} }
@ -423,7 +416,7 @@ func contractCallEx(ic *interop.Context, v *vm.VM) error {
return contractCallExInternal(ic, v, h, method, args, flags) return contractCallExInternal(ic, v, h, method, args, flags)
} }
func contractCallExInternal(ic *interop.Context, v *vm.VM, h []byte, method stackitem.Item, args stackitem.Item, _ smartcontract.CallFlag) error { func contractCallExInternal(ic *interop.Context, v *vm.VM, h []byte, method stackitem.Item, args stackitem.Item, f smartcontract.CallFlag) error {
u, err := util.Uint160DecodeBytesBE(h) u, err := util.Uint160DecodeBytesBE(h)
if err != nil { if err != nil {
return errors.New("invalid contract hash") return errors.New("invalid contract hash")
@ -442,7 +435,7 @@ func contractCallExInternal(ic *interop.Context, v *vm.VM, h []byte, method stac
return errors.New("disallowed method call") return errors.New("disallowed method call")
} }
} }
v.LoadScript(cs.Script) v.LoadScriptWithHash(cs.Script, u, v.Context().GetCallFlags()&f)
v.Estack().PushVal(args) v.Estack().PushVal(args)
v.Estack().PushVal(method) v.Estack().PushVal(method)
return nil return nil
@ -450,9 +443,6 @@ func contractCallExInternal(ic *interop.Context, v *vm.VM, h []byte, method stac
// contractDestroy destroys a contract. // contractDestroy destroys a contract.
func contractDestroy(ic *interop.Context, v *vm.VM) error { func contractDestroy(ic *interop.Context, v *vm.VM) error {
if ic.Trigger != trigger.Application {
return errors.New("can't destroy contract when not triggered by application")
}
hash := v.GetCurrentScriptHash() hash := v.GetCurrentScriptHash()
cs, err := ic.DAO.GetContractState(hash) cs, err := ic.DAO.GetContractState(hash)
if err != nil { if err != nil {

View file

@ -16,6 +16,8 @@ import (
"github.com/nspcc-dev/neo-go/pkg/core/interop/iterator" "github.com/nspcc-dev/neo-go/pkg/core/interop/iterator"
"github.com/nspcc-dev/neo-go/pkg/core/interop/runtime" "github.com/nspcc-dev/neo-go/pkg/core/interop/runtime"
"github.com/nspcc-dev/neo-go/pkg/core/native" "github.com/nspcc-dev/neo-go/pkg/core/native"
"github.com/nspcc-dev/neo-go/pkg/smartcontract"
"github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger"
"github.com/nspcc-dev/neo-go/pkg/vm" "github.com/nspcc-dev/neo-go/pkg/vm"
"github.com/nspcc-dev/neo-go/pkg/vm/emit" "github.com/nspcc-dev/neo-go/pkg/vm/emit"
) )
@ -23,7 +25,7 @@ import (
// SpawnVM returns a VM with script getter and interop functions set // SpawnVM returns a VM with script getter and interop functions set
// up for current blockchain. // up for current blockchain.
func SpawnVM(ic *interop.Context) *vm.VM { func SpawnVM(ic *interop.Context) *vm.VM {
vm := vm.New() vm := vm.NewWithTrigger(ic.Trigger)
vm.RegisterInteropGetter(getSystemInterop(ic)) vm.RegisterInteropGetter(getSystemInterop(ic))
vm.RegisterInteropGetter(getNeoInterop(ic)) vm.RegisterInteropGetter(getNeoInterop(ic))
if ic.Chain != nil { if ic.Chain != nil {
@ -52,9 +54,13 @@ func getInteropFromSlice(ic *interop.Context, slice []interop.Function) func(uin
return slice[i].ID >= id return slice[i].ID >= id
}) })
if n < len(slice) && slice[n].ID == id { if n < len(slice) && slice[n].ID == id {
return &vm.InteropFuncPrice{Func: func(v *vm.VM) error { return &vm.InteropFuncPrice{
return slice[n].Func(ic, v) Func: func(v *vm.VM) error {
}, Price: slice[n].Price} return slice[n].Func(ic, v)
},
Price: slice[n].Price,
RequiredFlags: slice[n].RequiredFlags,
}
} }
return nil return nil
} }
@ -62,17 +68,28 @@ func getInteropFromSlice(ic *interop.Context, slice []interop.Function) func(uin
// All lists are sorted, keep 'em this way, please. // All lists are sorted, keep 'em this way, please.
var systemInterops = []interop.Function{ var systemInterops = []interop.Function{
{Name: "System.Blockchain.GetBlock", Func: bcGetBlock, Price: 250}, {Name: "System.Blockchain.GetBlock", Func: bcGetBlock, Price: 250,
{Name: "System.Blockchain.GetContract", Func: bcGetContract, Price: 100}, AllowedTriggers: trigger.Application, RequiredFlags: smartcontract.AllowStates},
{Name: "System.Blockchain.GetHeight", Func: bcGetHeight, Price: 1}, {Name: "System.Blockchain.GetContract", Func: bcGetContract, Price: 100,
{Name: "System.Blockchain.GetTransaction", Func: bcGetTransaction, Price: 100}, AllowedTriggers: trigger.Application, RequiredFlags: smartcontract.AllowStates},
{Name: "System.Blockchain.GetTransactionFromBlock", Func: bcGetTransactionFromBlock, Price: 100}, {Name: "System.Blockchain.GetHeight", Func: bcGetHeight, Price: 1,
{Name: "System.Blockchain.GetTransactionHeight", Func: bcGetTransactionHeight, Price: 100}, AllowedTriggers: trigger.Application, RequiredFlags: smartcontract.AllowStates},
{Name: "System.Contract.Call", Func: contractCall, Price: 1}, {Name: "System.Blockchain.GetTransaction", Func: bcGetTransaction, Price: 100,
{Name: "System.Contract.CallEx", Func: contractCallEx, Price: 1}, AllowedTriggers: trigger.Application, RequiredFlags: smartcontract.AllowStates},
{Name: "System.Contract.Create", Func: contractCreate, Price: 0}, {Name: "System.Blockchain.GetTransactionFromBlock", Func: bcGetTransactionFromBlock, Price: 100,
{Name: "System.Contract.Destroy", Func: contractDestroy, Price: 1}, AllowedTriggers: trigger.Application, RequiredFlags: smartcontract.AllowStates},
{Name: "System.Contract.Update", Func: contractUpdate, Price: 0}, {Name: "System.Blockchain.GetTransactionHeight", Func: bcGetTransactionHeight, Price: 100,
AllowedTriggers: trigger.Application, RequiredFlags: smartcontract.AllowStates},
{Name: "System.Contract.Call", Func: contractCall, Price: 1,
AllowedTriggers: trigger.System | trigger.Application, RequiredFlags: smartcontract.AllowCall},
{Name: "System.Contract.CallEx", Func: contractCallEx, Price: 1,
AllowedTriggers: trigger.System | trigger.Application, RequiredFlags: smartcontract.AllowCall},
{Name: "System.Contract.Create", Func: contractCreate, Price: 0,
AllowedTriggers: trigger.Application, RequiredFlags: smartcontract.AllowModifyStates},
{Name: "System.Contract.Destroy", Func: contractDestroy, Price: 1,
AllowedTriggers: trigger.Application, RequiredFlags: smartcontract.AllowModifyStates},
{Name: "System.Contract.Update", Func: contractUpdate, Price: 0,
AllowedTriggers: trigger.Application, RequiredFlags: smartcontract.AllowModifyStates},
{Name: "System.Enumerator.Concat", Func: enumerator.Concat, Price: 1}, {Name: "System.Enumerator.Concat", Func: enumerator.Concat, Price: 1},
{Name: "System.Enumerator.Create", Func: enumerator.Create, Price: 1}, {Name: "System.Enumerator.Create", Func: enumerator.Create, Price: 1},
{Name: "System.Enumerator.Next", Func: enumerator.Next, Price: 1}, {Name: "System.Enumerator.Next", Func: enumerator.Next, Price: 1},
@ -86,29 +103,39 @@ var systemInterops = []interop.Function{
{Name: "System.Iterator.Key", Func: iterator.Key, Price: 1}, {Name: "System.Iterator.Key", Func: iterator.Key, Price: 1},
{Name: "System.Iterator.Keys", Func: iterator.Keys, Price: 1}, {Name: "System.Iterator.Keys", Func: iterator.Keys, Price: 1},
{Name: "System.Iterator.Values", Func: iterator.Values, Price: 1}, {Name: "System.Iterator.Values", Func: iterator.Values, Price: 1},
{Name: "System.Runtime.CheckWitness", Func: runtime.CheckWitness, Price: 200}, {Name: "System.Runtime.CheckWitness", Func: runtime.CheckWitness, Price: 200, RequiredFlags: smartcontract.AllowStates},
{Name: "System.Runtime.Deserialize", Func: runtimeDeserialize, Price: 1}, {Name: "System.Runtime.Deserialize", Func: runtimeDeserialize, Price: 1},
{Name: "System.Runtime.GetTime", Func: runtimeGetTime, Price: 1}, {Name: "System.Runtime.GetTime", Func: runtimeGetTime, Price: 1,
AllowedTriggers: trigger.Application, RequiredFlags: smartcontract.AllowStates},
{Name: "System.Runtime.GetTrigger", Func: runtimeGetTrigger, Price: 1}, {Name: "System.Runtime.GetTrigger", Func: runtimeGetTrigger, Price: 1},
{Name: "System.Runtime.Log", Func: runtimeLog, Price: 1}, {Name: "System.Runtime.Log", Func: runtimeLog, Price: 1, RequiredFlags: smartcontract.AllowNotify},
{Name: "System.Runtime.Notify", Func: runtimeNotify, Price: 1}, {Name: "System.Runtime.Notify", Func: runtimeNotify, Price: 1, RequiredFlags: smartcontract.AllowNotify},
{Name: "System.Runtime.Platform", Func: runtimePlatform, Price: 1}, {Name: "System.Runtime.Platform", Func: runtimePlatform, Price: 1},
{Name: "System.Runtime.Serialize", Func: runtimeSerialize, Price: 1}, {Name: "System.Runtime.Serialize", Func: runtimeSerialize, Price: 1},
{Name: "System.Storage.Delete", Func: storageDelete, Price: 100}, {Name: "System.Storage.Delete", Func: storageDelete, Price: 100,
{Name: "System.Storage.Find", Func: storageFind, Price: 1}, AllowedTriggers: trigger.Application, RequiredFlags: smartcontract.AllowModifyStates},
{Name: "System.Storage.Get", Func: storageGet, Price: 100}, {Name: "System.Storage.Find", Func: storageFind, Price: 1,
{Name: "System.Storage.GetContext", Func: storageGetContext, Price: 1}, AllowedTriggers: trigger.Application, RequiredFlags: smartcontract.AllowStates},
{Name: "System.Storage.GetReadOnlyContext", Func: storageGetReadOnlyContext, Price: 1}, {Name: "System.Storage.Get", Func: storageGet, Price: 100,
{Name: "System.Storage.Put", Func: storagePut, Price: 0}, // These don't have static price in C# code. AllowedTriggers: trigger.Application, RequiredFlags: smartcontract.AllowStates},
{Name: "System.Storage.PutEx", Func: storagePutEx, Price: 0}, {Name: "System.Storage.GetContext", Func: storageGetContext, Price: 1,
{Name: "System.Storage.AsReadOnly", Func: storageContextAsReadOnly, Price: 1}, AllowedTriggers: trigger.Application, RequiredFlags: smartcontract.AllowStates},
{Name: "System.Storage.GetReadOnlyContext", Func: storageGetReadOnlyContext, Price: 1,
AllowedTriggers: trigger.Application, RequiredFlags: smartcontract.AllowStates},
{Name: "System.Storage.Put", Func: storagePut, Price: 0,
AllowedTriggers: trigger.Application, RequiredFlags: smartcontract.AllowModifyStates}, // These don't have static price in C# code.
{Name: "System.Storage.PutEx", Func: storagePutEx, Price: 0,
AllowedTriggers: trigger.Application, RequiredFlags: smartcontract.AllowModifyStates},
{Name: "System.Storage.AsReadOnly", Func: storageContextAsReadOnly, Price: 1,
AllowedTriggers: trigger.Application, RequiredFlags: smartcontract.AllowStates},
} }
var neoInterops = []interop.Function{ var neoInterops = []interop.Function{
{Name: "Neo.Crypto.ECDsaVerify", Func: crypto.ECDSAVerify, Price: 1}, {Name: "Neo.Crypto.ECDsaVerify", Func: crypto.ECDSAVerify, Price: 1},
{Name: "Neo.Crypto.ECDsaCheckMultiSig", Func: crypto.ECDSACheckMultisig, Price: 1}, {Name: "Neo.Crypto.ECDsaCheckMultiSig", Func: crypto.ECDSACheckMultisig, Price: 1},
{Name: "Neo.Crypto.SHA256", Func: crypto.Sha256, Price: 1}, {Name: "Neo.Crypto.SHA256", Func: crypto.Sha256, Price: 1},
{Name: "Neo.Native.Deploy", Func: native.Deploy, Price: 1}, {Name: "Neo.Native.Deploy", Func: native.Deploy, Price: 1,
AllowedTriggers: trigger.Application, RequiredFlags: smartcontract.AllowModifyStates},
} }
// initIDinInteropsSlice initializes IDs from names in one given // initIDinInteropsSlice initializes IDs from names in one given

View file

@ -79,6 +79,9 @@ func getNativeInterop(ic *interop.Context, c interop.Contract) func(v *vm.VM) er
if !ok { if !ok {
return fmt.Errorf("method %s not found", name) return fmt.Errorf("method %s not found", name)
} }
if !v.Context().GetCallFlags().Has(m.RequiredFlags) {
return errors.New("missing call flags")
}
result := m.Func(ic, args) result := m.Func(ic, args)
v.Estack().PushVal(result) v.Estack().PushVal(result)
return nil return nil

View file

@ -77,30 +77,30 @@ func NewNEO() *NEO {
desc := newDescriptor("unclaimedGas", smartcontract.IntegerType, desc := newDescriptor("unclaimedGas", smartcontract.IntegerType,
manifest.NewParameter("account", smartcontract.Hash160Type), manifest.NewParameter("account", smartcontract.Hash160Type),
manifest.NewParameter("end", smartcontract.IntegerType)) manifest.NewParameter("end", smartcontract.IntegerType))
md := newMethodAndPrice(n.unclaimedGas, 1, smartcontract.NoneFlag) md := newMethodAndPrice(n.unclaimedGas, 1, smartcontract.AllowStates)
n.AddMethod(md, desc, true) n.AddMethod(md, desc, true)
desc = newDescriptor("registerValidator", smartcontract.BoolType, desc = newDescriptor("registerValidator", smartcontract.BoolType,
manifest.NewParameter("pubkey", smartcontract.PublicKeyType)) manifest.NewParameter("pubkey", smartcontract.PublicKeyType))
md = newMethodAndPrice(n.registerValidator, 1, smartcontract.NoneFlag) md = newMethodAndPrice(n.registerValidator, 1, smartcontract.AllowModifyStates)
n.AddMethod(md, desc, false) n.AddMethod(md, desc, false)
desc = newDescriptor("vote", smartcontract.BoolType, desc = newDescriptor("vote", smartcontract.BoolType,
manifest.NewParameter("account", smartcontract.Hash160Type), manifest.NewParameter("account", smartcontract.Hash160Type),
manifest.NewParameter("pubkeys", smartcontract.ArrayType)) manifest.NewParameter("pubkeys", smartcontract.ArrayType))
md = newMethodAndPrice(n.vote, 1, smartcontract.NoneFlag) md = newMethodAndPrice(n.vote, 1, smartcontract.AllowModifyStates)
n.AddMethod(md, desc, false) n.AddMethod(md, desc, false)
desc = newDescriptor("getRegisteredValidators", smartcontract.ArrayType) desc = newDescriptor("getRegisteredValidators", smartcontract.ArrayType)
md = newMethodAndPrice(n.getRegisteredValidatorsCall, 1, smartcontract.NoneFlag) md = newMethodAndPrice(n.getRegisteredValidatorsCall, 1, smartcontract.AllowStates)
n.AddMethod(md, desc, true) n.AddMethod(md, desc, true)
desc = newDescriptor("getValidators", smartcontract.ArrayType) desc = newDescriptor("getValidators", smartcontract.ArrayType)
md = newMethodAndPrice(n.getValidators, 1, smartcontract.NoneFlag) md = newMethodAndPrice(n.getValidators, 1, smartcontract.AllowStates)
n.AddMethod(md, desc, true) n.AddMethod(md, desc, true)
desc = newDescriptor("getNextBlockValidators", smartcontract.ArrayType) desc = newDescriptor("getNextBlockValidators", smartcontract.ArrayType)
md = newMethodAndPrice(n.getNextBlockValidators, 1, smartcontract.NoneFlag) md = newMethodAndPrice(n.getNextBlockValidators, 1, smartcontract.AllowStates)
n.AddMethod(md, desc, true) n.AddMethod(md, desc, true)
return n return n

View file

@ -61,12 +61,12 @@ func newNEP5Native(name string) *nep5TokenNative {
n.AddMethod(md, desc, true) n.AddMethod(md, desc, true)
desc = newDescriptor("totalSupply", smartcontract.IntegerType) desc = newDescriptor("totalSupply", smartcontract.IntegerType)
md = newMethodAndPrice(n.TotalSupply, 1, smartcontract.NoneFlag) md = newMethodAndPrice(n.TotalSupply, 1, smartcontract.AllowStates)
n.AddMethod(md, desc, true) n.AddMethod(md, desc, true)
desc = newDescriptor("balanceOf", smartcontract.IntegerType, desc = newDescriptor("balanceOf", smartcontract.IntegerType,
manifest.NewParameter("account", smartcontract.Hash160Type)) manifest.NewParameter("account", smartcontract.Hash160Type))
md = newMethodAndPrice(n.balanceOf, 1, smartcontract.NoneFlag) md = newMethodAndPrice(n.balanceOf, 1, smartcontract.AllowStates)
n.AddMethod(md, desc, true) n.AddMethod(md, desc, true)
desc = newDescriptor("transfer", smartcontract.BoolType, desc = newDescriptor("transfer", smartcontract.BoolType,
@ -74,7 +74,7 @@ func newNEP5Native(name string) *nep5TokenNative {
manifest.NewParameter("to", smartcontract.Hash160Type), manifest.NewParameter("to", smartcontract.Hash160Type),
manifest.NewParameter("amount", smartcontract.IntegerType), manifest.NewParameter("amount", smartcontract.IntegerType),
) )
md = newMethodAndPrice(n.Transfer, 1, smartcontract.NoneFlag) md = newMethodAndPrice(n.Transfer, 1, smartcontract.AllowModifyStates)
n.AddMethod(md, desc, false) n.AddMethod(md, desc, false)
n.AddEvent("Transfer", desc.Parameters...) n.AddEvent("Transfer", desc.Parameters...)

View file

@ -915,7 +915,7 @@ func (s *Server) invokescript(reqParams request.Params) (interface{}, *response.
func (s *Server) runScriptInVM(script []byte) *result.Invoke { func (s *Server) runScriptInVM(script []byte) *result.Invoke {
vm := s.chain.GetTestVM() vm := s.chain.GetTestVM()
vm.SetGasLimit(s.config.MaxGasInvoke) vm.SetGasLimit(s.config.MaxGasInvoke)
vm.LoadScript(script) vm.LoadScriptWithFlags(script, smartcontract.All)
_ = vm.Run() _ = vm.Run()
result := &result.Invoke{ result := &result.Invoke{
State: vm.State(), State: vm.State(),

View file

@ -5,10 +5,16 @@ type CallFlag byte
// Default flags. // Default flags.
const ( const (
NoneFlag CallFlag = 0 AllowStates CallFlag = 1 << iota
AllowModifyStates CallFlag = 1 << iota AllowModifyStates
AllowCall AllowCall
AllowNotify AllowNotify
ReadOnly = AllowCall | AllowNotify ReadOnly = AllowStates | AllowCall | AllowNotify
All = AllowModifyStates | AllowCall | AllowNotify All = ReadOnly | AllowModifyStates
NoneFlag CallFlag = 0
) )
// Has returns true iff all bits set in cf are also set in f.
func (f CallFlag) Has(cf CallFlag) bool {
return f&cf == cf
}

View file

@ -0,0 +1,14 @@
package smartcontract
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestCallFlag_Has(t *testing.T) {
require.True(t, AllowCall.Has(AllowCall))
require.True(t, (AllowCall | AllowNotify).Has(AllowCall))
require.False(t, (AllowCall).Has(AllowCall|AllowNotify))
require.True(t, All.Has(ReadOnly))
}

View file

@ -1,41 +1,29 @@
package trigger package trigger
//go:generate stringer -type=Type //go:generate stringer -type=Type -output=trigger_type_string.go
// Type represents trigger type used in C# reference node: https://github.com/neo-project/neo/blob/c64748ecbac3baeb8045b16af0d518398a6ced24/neo/SmartContract/TriggerType.cs#L3 // Type represents trigger type used in C# reference node: https://github.com/neo-project/neo/blob/c64748ecbac3baeb8045b16af0d518398a6ced24/neo/SmartContract/TriggerType.cs#L3
type Type byte type Type byte
// Viable list of supported trigger type constants. // Viable list of supported trigger type constants.
const ( const (
// System is trigger type that indicates that script is being invoke internally by the system.
System Type = 0x01
// The verification trigger indicates that the contract is being invoked as a verification function. // The verification trigger indicates that the contract is being invoked as a verification function.
// The verification function can accept multiple parameters, and should return a boolean value that indicates the validity of the transaction or block. // The verification function can accept multiple parameters, and should return a boolean value that indicates the validity of the transaction or block.
// The entry point of the contract will be invoked if the contract is triggered by Verification: // The entry point of the contract will be invoked if the contract is triggered by Verification:
// main(...); // main(...);
// The entry point of the contract must be able to handle this type of invocation. // The entry point of the contract must be able to handle this type of invocation.
Verification Type = 0x00 Verification Type = 0x20
// The verificationR trigger indicates that the contract is being invoked as a verification function because it is specified as a target of an output of the transaction.
// The verification function accepts no parameter, and should return a boolean value that indicates the validity of the transaction.
// The entry point of the contract will be invoked if the contract is triggered by VerificationR:
// main("receiving", new object[0]);
// The receiving function should have the following signature:
// public bool receiving()
// The receiving function will be invoked automatically when a contract is receiving assets from a transfer.
VerificationR Type = 0x01
// The application trigger indicates that the contract is being invoked as an application function. // The application trigger indicates that the contract is being invoked as an application function.
// The application function can accept multiple parameters, change the states of the blockchain, and return any type of value. // The application function can accept multiple parameters, change the states of the blockchain, and return any type of value.
// The contract can have any form of entry point, but we recommend that all contracts should have the following entry point: // The contract can have any form of entry point, but we recommend that all contracts should have the following entry point:
// public byte[] main(string operation, params object[] args) // public byte[] main(string operation, params object[] args)
// The functions can be invoked by creating an InvocationTransaction. // The functions can be invoked by creating an InvocationTransaction.
Application Type = 0x10 Application Type = 0x40
// The ApplicationR trigger indicates that the default function received of the contract is being invoked because it is specified as a target of an output of the transaction. // All represents any trigger type.
// The received function accepts no parameter, changes the states of the blockchain, and returns any type of value. All = System | Verification | Application
// The entry point of the contract will be invoked if the contract is triggered by ApplicationR:
// main("received", new object[0]);
// The received function should have the following signature:
// public byte[] received()
// The received function will be invoked automatically when a contract is receiving assets from a transfer.
ApplicationR Type = 0x11
) )

View file

@ -1,4 +1,4 @@
// Code generated by "stringer -type=Type"; DO NOT EDIT. // Code generated by "stringer -type=Type -output=trigger_type_string.go"; DO NOT EDIT.
package trigger package trigger
@ -8,29 +8,25 @@ func _() {
// An "invalid array index" compiler error signifies that the constant values have changed. // An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again. // Re-run the stringer command to generate them again.
var x [1]struct{} var x [1]struct{}
_ = x[Verification-0] _ = x[System-1]
_ = x[VerificationR-1] _ = x[Verification-32]
_ = x[Application-16] _ = x[Application-64]
_ = x[ApplicationR-17]
} }
const ( const (
_Type_name_0 = "VerificationVerificationR" _Type_name_0 = "System"
_Type_name_1 = "ApplicationApplicationR" _Type_name_1 = "Verification"
) _Type_name_2 = "Application"
var (
_Type_index_0 = [...]uint8{0, 12, 25}
_Type_index_1 = [...]uint8{0, 11, 23}
) )
func (i Type) String() string { func (i Type) String() string {
switch { switch {
case i <= 1: case i == 1:
return _Type_name_0[_Type_index_0[i]:_Type_index_0[i+1]] return _Type_name_0
case 16 <= i && i <= 17: case i == 32:
i -= 16 return _Type_name_1
return _Type_name_1[_Type_index_1[i]:_Type_index_1[i+1]] case i == 64:
return _Type_name_2
default: default:
return "Type(" + strconv.FormatInt(int64(i), 10) + ")" return "Type(" + strconv.FormatInt(int64(i), 10) + ")"
} }

View file

@ -8,10 +8,9 @@ import (
func TestStringer(t *testing.T) { func TestStringer(t *testing.T) {
tests := map[Type]string{ tests := map[Type]string{
Application: "Application", System: "System",
ApplicationR: "ApplicationR", Application: "Application",
Verification: "Verification", Verification: "Verification",
VerificationR: "VerificationR",
} }
for o, s := range tests { for o, s := range tests {
assert.Equal(t, s, o.String()) assert.Equal(t, s, o.String())
@ -20,10 +19,9 @@ func TestStringer(t *testing.T) {
func TestEncodeBynary(t *testing.T) { func TestEncodeBynary(t *testing.T) {
tests := map[Type]byte{ tests := map[Type]byte{
Verification: 0x00, System: 0x01,
VerificationR: 0x01, Verification: 0x20,
Application: 0x10, Application: 0x40,
ApplicationR: 0x11,
} }
for o, b := range tests { for o, b := range tests {
assert.Equal(t, b, byte(o)) assert.Equal(t, b, byte(o))
@ -32,10 +30,9 @@ func TestEncodeBynary(t *testing.T) {
func TestDecodeBynary(t *testing.T) { func TestDecodeBynary(t *testing.T) {
tests := map[Type]byte{ tests := map[Type]byte{
Verification: 0x00, System: 0x01,
VerificationR: 0x01, Verification: 0x20,
Application: 0x10, Application: 0x40,
ApplicationR: 0x11,
} }
for o, b := range tests { for o, b := range tests {
assert.Equal(t, o, Type(b)) assert.Equal(t, o, Type(b))

View file

@ -6,6 +6,7 @@ import (
"math/big" "math/big"
"github.com/nspcc-dev/neo-go/pkg/crypto/hash" "github.com/nspcc-dev/neo-go/pkg/crypto/hash"
"github.com/nspcc-dev/neo-go/pkg/smartcontract"
"github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/util"
"github.com/nspcc-dev/neo-go/pkg/vm/opcode" "github.com/nspcc-dev/neo-go/pkg/vm/opcode"
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
@ -39,6 +40,9 @@ type Context struct {
// Script hash of the prog. // Script hash of the prog.
scriptHash util.Uint160 scriptHash util.Uint160
// Call flags this context was created with.
callFlag smartcontract.CallFlag
} }
var errNoInstParam = errors.New("failed to read instruction parameter") var errNoInstParam = errors.New("failed to read instruction parameter")
@ -154,6 +158,11 @@ func (c *Context) Copy() *Context {
return ctx return ctx
} }
// GetCallFlags returns calling flags context was created with.
func (c *Context) GetCallFlags() smartcontract.CallFlag {
return c.callFlag
}
// Program returns the loaded program. // Program returns the loaded program.
func (c *Context) Program() []byte { func (c *Context) Program() []byte {
return c.prog return c.prog

View file

@ -5,6 +5,8 @@ import (
"fmt" "fmt"
"sort" "sort"
"github.com/nspcc-dev/neo-go/pkg/smartcontract"
"github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger"
"github.com/nspcc-dev/neo-go/pkg/vm/emit" "github.com/nspcc-dev/neo-go/pkg/vm/emit"
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
) )
@ -16,6 +18,10 @@ type InteropFunc func(vm *VM) error
type InteropFuncPrice struct { type InteropFuncPrice struct {
Func InteropFunc Func InteropFunc
Price int Price int
// AllowedTriggers is a mask representing triggers which should be allowed by an interop.
// 0 is interpreted as All.
AllowedTriggers trigger.Type
RequiredFlags smartcontract.CallFlag
} }
// interopIDFuncPrice adds an ID to the InteropFuncPrice. // interopIDFuncPrice adds an ID to the InteropFuncPrice.
@ -30,31 +36,31 @@ type InteropGetterFunc func(uint32) *InteropFuncPrice
var defaultVMInterops = []interopIDFuncPrice{ var defaultVMInterops = []interopIDFuncPrice{
{emit.InteropNameToID([]byte("System.Runtime.Log")), {emit.InteropNameToID([]byte("System.Runtime.Log")),
InteropFuncPrice{runtimeLog, 1}}, InteropFuncPrice{Func: runtimeLog, Price: 1}},
{emit.InteropNameToID([]byte("System.Runtime.Notify")), {emit.InteropNameToID([]byte("System.Runtime.Notify")),
InteropFuncPrice{runtimeNotify, 1}}, InteropFuncPrice{Func: runtimeNotify, Price: 1}},
{emit.InteropNameToID([]byte("System.Runtime.Serialize")), {emit.InteropNameToID([]byte("System.Runtime.Serialize")),
InteropFuncPrice{RuntimeSerialize, 1}}, InteropFuncPrice{Func: RuntimeSerialize, Price: 1}},
{emit.InteropNameToID([]byte("System.Runtime.Deserialize")), {emit.InteropNameToID([]byte("System.Runtime.Deserialize")),
InteropFuncPrice{RuntimeDeserialize, 1}}, InteropFuncPrice{Func: RuntimeDeserialize, Price: 1}},
{emit.InteropNameToID([]byte("System.Enumerator.Create")), {emit.InteropNameToID([]byte("System.Enumerator.Create")),
InteropFuncPrice{EnumeratorCreate, 1}}, InteropFuncPrice{Func: EnumeratorCreate, Price: 1}},
{emit.InteropNameToID([]byte("System.Enumerator.Next")), {emit.InteropNameToID([]byte("System.Enumerator.Next")),
InteropFuncPrice{EnumeratorNext, 1}}, InteropFuncPrice{Func: EnumeratorNext, Price: 1}},
{emit.InteropNameToID([]byte("System.Enumerator.Concat")), {emit.InteropNameToID([]byte("System.Enumerator.Concat")),
InteropFuncPrice{EnumeratorConcat, 1}}, InteropFuncPrice{Func: EnumeratorConcat, Price: 1}},
{emit.InteropNameToID([]byte("System.Enumerator.Value")), {emit.InteropNameToID([]byte("System.Enumerator.Value")),
InteropFuncPrice{EnumeratorValue, 1}}, InteropFuncPrice{Func: EnumeratorValue, Price: 1}},
{emit.InteropNameToID([]byte("System.Iterator.Create")), {emit.InteropNameToID([]byte("System.Iterator.Create")),
InteropFuncPrice{IteratorCreate, 1}}, InteropFuncPrice{Func: IteratorCreate, Price: 1}},
{emit.InteropNameToID([]byte("System.Iterator.Concat")), {emit.InteropNameToID([]byte("System.Iterator.Concat")),
InteropFuncPrice{IteratorConcat, 1}}, InteropFuncPrice{Func: IteratorConcat, Price: 1}},
{emit.InteropNameToID([]byte("System.Iterator.Key")), {emit.InteropNameToID([]byte("System.Iterator.Key")),
InteropFuncPrice{IteratorKey, 1}}, InteropFuncPrice{Func: IteratorKey, Price: 1}},
{emit.InteropNameToID([]byte("System.Iterator.Keys")), {emit.InteropNameToID([]byte("System.Iterator.Keys")),
InteropFuncPrice{IteratorKeys, 1}}, InteropFuncPrice{Func: IteratorKeys, Price: 1}},
{emit.InteropNameToID([]byte("System.Iterator.Values")), {emit.InteropNameToID([]byte("System.Iterator.Values")),
InteropFuncPrice{IteratorValues, 1}}, InteropFuncPrice{Func: IteratorValues, Price: 1}},
} }
func getDefaultVMInterop(id uint32) *InteropFuncPrice { func getDefaultVMInterop(id uint32) *InteropFuncPrice {

View file

@ -17,6 +17,8 @@ import (
"testing" "testing"
"github.com/nspcc-dev/neo-go/pkg/encoding/bigint" "github.com/nspcc-dev/neo-go/pkg/encoding/bigint"
"github.com/nspcc-dev/neo-go/pkg/smartcontract"
"github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger"
"github.com/nspcc-dev/neo-go/pkg/vm/opcode" "github.com/nspcc-dev/neo-go/pkg/vm/opcode"
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -111,11 +113,23 @@ func TestUT(t *testing.T) {
} }
func getTestingInterop(id uint32) *InteropFuncPrice { func getTestingInterop(id uint32) *InteropFuncPrice {
if id == binary.LittleEndian.Uint32([]byte{0x77, 0x77, 0x77, 0x77}) { f := func(v *VM) error {
return &InteropFuncPrice{InteropFunc(func(v *VM) error { v.estack.PushVal(stackitem.NewInterop(new(int)))
v.estack.PushVal(stackitem.NewInterop(new(int))) return nil
return nil }
}), 0} switch id {
case binary.LittleEndian.Uint32([]byte{0x77, 0x77, 0x77, 0x77}):
return &InteropFuncPrice{Func: f}
case binary.LittleEndian.Uint32([]byte{0x66, 0x66, 0x66, 0x66}):
return &InteropFuncPrice{
Func: f,
RequiredFlags: smartcontract.ReadOnly,
}
case binary.LittleEndian.Uint32([]byte{0x55, 0x55, 0x55, 0x55}):
return &InteropFuncPrice{
Func: f,
AllowedTriggers: trigger.Application,
}
} }
return nil return nil
} }

View file

@ -13,6 +13,8 @@ import (
"github.com/nspcc-dev/neo-go/pkg/crypto/keys" "github.com/nspcc-dev/neo-go/pkg/crypto/keys"
"github.com/nspcc-dev/neo-go/pkg/encoding/bigint" "github.com/nspcc-dev/neo-go/pkg/encoding/bigint"
"github.com/nspcc-dev/neo-go/pkg/smartcontract"
"github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger"
"github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/util"
"github.com/nspcc-dev/neo-go/pkg/vm/opcode" "github.com/nspcc-dev/neo-go/pkg/vm/opcode"
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
@ -84,18 +86,26 @@ type VM struct {
gasConsumed util.Fixed8 gasConsumed util.Fixed8
gasLimit util.Fixed8 gasLimit util.Fixed8
trigger trigger.Type
// Public keys cache. // Public keys cache.
keys map[string]*keys.PublicKey keys map[string]*keys.PublicKey
} }
// New returns a new VM object ready to load .avm bytecode scripts. // New returns a new VM object ready to load .avm bytecode scripts.
func New() *VM { func New() *VM {
return NewWithTrigger(trigger.System)
}
// NewWithTrigger returns a new VM for executions triggered by t.
func NewWithTrigger(t trigger.Type) *VM {
vm := &VM{ vm := &VM{
getInterop: make([]InteropGetterFunc, 0, 3), // 3 functions is typical for our default usage. getInterop: make([]InteropGetterFunc, 0, 3), // 3 functions is typical for our default usage.
state: haltState, state: haltState,
istack: NewStack("invocation"), istack: NewStack("invocation"),
refs: newRefCounter(), refs: newRefCounter(),
keys: make(map[string]*keys.PublicKey), keys: make(map[string]*keys.PublicKey),
trigger: t,
} }
vm.estack = vm.newItemStack("evaluation") vm.estack = vm.newItemStack("evaluation")
@ -262,12 +272,27 @@ func (v *VM) Load(prog []byte) {
// will immediately push a new context created from this script to // will immediately push a new context created from this script to
// the invocation stack and starts executing it. // the invocation stack and starts executing it.
func (v *VM) LoadScript(b []byte) { func (v *VM) LoadScript(b []byte) {
v.LoadScriptWithFlags(b, smartcontract.NoneFlag)
}
// LoadScriptWithFlags loads script and sets call flag to f.
func (v *VM) LoadScriptWithFlags(b []byte, f smartcontract.CallFlag) {
ctx := NewContext(b) ctx := NewContext(b)
ctx.estack = v.estack ctx.estack = v.estack
ctx.astack = v.astack ctx.astack = v.astack
ctx.callFlag = f
v.istack.PushVal(ctx) v.istack.PushVal(ctx)
} }
// LoadScriptWithHash if similar to the LoadScriptWithFlags method, but it also loads
// given script hash directly into the Context to avoid its recalculations. It's
// up to user of this function to make sure the script and hash match each other.
func (v *VM) LoadScriptWithHash(b []byte, hash util.Uint160, f smartcontract.CallFlag) {
v.LoadScriptWithFlags(b, f)
ctx := v.Context()
ctx.scriptHash = hash
}
// Context returns the current executed context. Nil if there is no context, // Context returns the current executed context. Nil if there is no context,
// which implies no program is loaded. // which implies no program is loaded.
func (v *VM) Context() *Context { func (v *VM) Context() *Context {
@ -1244,6 +1269,12 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro
case opcode.SYSCALL: case opcode.SYSCALL:
interopID := GetInteropID(parameter) interopID := GetInteropID(parameter)
ifunc := v.GetInteropByID(interopID) ifunc := v.GetInteropByID(interopID)
if ifunc.AllowedTriggers != 0 && ifunc.AllowedTriggers&v.trigger == 0 {
panic(fmt.Sprintf("trigger not allowed: %s", v.trigger))
}
if !v.Context().callFlag.Has(ifunc.RequiredFlags) {
panic(fmt.Sprintf("missing call flags: %05b vs %05b", v.Context().callFlag, ifunc.RequiredFlags))
}
if ifunc == nil { if ifunc == nil {
panic(fmt.Sprintf("interop hook (%q/0x%x) not registered", parameter, interopID)) panic(fmt.Sprintf("interop hook (%q/0x%x) not registered", parameter, interopID))

View file

@ -12,6 +12,8 @@ import (
"github.com/nspcc-dev/neo-go/pkg/encoding/bigint" "github.com/nspcc-dev/neo-go/pkg/encoding/bigint"
"github.com/nspcc-dev/neo-go/pkg/internal/random" "github.com/nspcc-dev/neo-go/pkg/internal/random"
"github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/io"
"github.com/nspcc-dev/neo-go/pkg/smartcontract"
"github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger"
"github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/util"
"github.com/nspcc-dev/neo-go/pkg/vm/emit" "github.com/nspcc-dev/neo-go/pkg/vm/emit"
"github.com/nspcc-dev/neo-go/pkg/vm/opcode" "github.com/nspcc-dev/neo-go/pkg/vm/opcode"
@ -22,10 +24,13 @@ import (
func fooInteropGetter(id uint32) *InteropFuncPrice { func fooInteropGetter(id uint32) *InteropFuncPrice {
if id == emit.InteropNameToID([]byte("foo")) { if id == emit.InteropNameToID([]byte("foo")) {
return &InteropFuncPrice{func(evm *VM) error { return &InteropFuncPrice{
evm.Estack().PushVal(1) Func: func(evm *VM) error {
return nil evm.Estack().PushVal(1)
}, 1} return nil
},
Price: 1,
}
} }
return nil return nil
} }
@ -812,6 +817,54 @@ func TestSerializeInterop(t *testing.T) {
require.True(t, vm.HasFailed()) require.True(t, vm.HasFailed())
} }
func getTestCallFlagsFunc(syscall []byte, flags smartcontract.CallFlag, result interface{}) func(t *testing.T) {
return func(t *testing.T) {
script := append([]byte{byte(opcode.SYSCALL)}, syscall...)
v := New()
v.RegisterInteropGetter(getTestingInterop)
v.LoadScriptWithFlags(script, flags)
if result == nil {
checkVMFailed(t, v)
return
}
runVM(t, v)
require.Equal(t, result, v.PopResult())
}
}
func TestCallFlags(t *testing.T) {
noFlags := []byte{0x77, 0x77, 0x77, 0x77}
readOnly := []byte{0x66, 0x66, 0x66, 0x66}
t.Run("NoFlagsNoRequired", getTestCallFlagsFunc(noFlags, smartcontract.NoneFlag, new(int)))
t.Run("ProvideFlagsNoRequired", getTestCallFlagsFunc(noFlags, smartcontract.AllowCall, new(int)))
t.Run("NoFlagsSomeRequired", getTestCallFlagsFunc(readOnly, smartcontract.NoneFlag, nil))
t.Run("OnlyOneProvided", getTestCallFlagsFunc(readOnly, smartcontract.AllowCall, nil))
t.Run("AllFlagsProvided", getTestCallFlagsFunc(readOnly, smartcontract.ReadOnly, new(int)))
}
func getTestTriggerFunc(syscall []byte, tr trigger.Type, result interface{}) func(t *testing.T) {
return func(t *testing.T) {
script := append([]byte{byte(opcode.SYSCALL)}, syscall...)
v := NewWithTrigger(tr)
v.RegisterInteropGetter(getTestingInterop)
v.LoadScript(script)
if result == nil {
checkVMFailed(t, v)
return
}
runVM(t, v)
require.Equal(t, result, v.PopResult())
}
}
func TestAllowedTriggers(t *testing.T) {
noFlags := []byte{0x77, 0x77, 0x77, 0x77}
appOnly := []byte{0x55, 0x55, 0x55, 0x55}
t.Run("Application/NeedNothing", getTestTriggerFunc(noFlags, trigger.Application, new(int)))
t.Run("Application/NeedApplication", getTestTriggerFunc(appOnly, trigger.Application, new(int)))
t.Run("System/NeedApplication", getTestTriggerFunc(appOnly, trigger.System, nil))
}
func callNTimes(n uint16) []byte { func callNTimes(n uint16) []byte {
return makeProgram( return makeProgram(
opcode.PUSHINT16, opcode.Opcode(n), opcode.Opcode(n>>8), // little-endian opcode.PUSHINT16, opcode.Opcode(n), opcode.Opcode(n>>8), // little-endian