diff --git a/pkg/compiler/interop_test.go b/pkg/compiler/interop_test.go index a56ea338b..7487b1493 100644 --- a/pkg/compiler/interop_test.go +++ b/pkg/compiler/interop_test.go @@ -14,6 +14,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/core/storage" "github.com/nspcc-dev/neo-go/pkg/crypto/hash" "github.com/nspcc-dev/neo-go/pkg/encoding/address" + "github.com/nspcc-dev/neo-go/pkg/smartcontract" "github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger" "github.com/nspcc-dev/neo-go/pkg/vm" "github.com/nspcc-dev/neo-go/pkg/vm/opcode" @@ -64,7 +65,7 @@ func spawnVM(t *testing.T, ic *interop.Context, src string) *vm.VM { b, err := compiler.Compile(strings.NewReader(src)) require.NoError(t, err) v := core.SpawnVM(ic) - v.Load(b) + v.LoadScriptWithFlags(b, smartcontract.All) return v } diff --git a/pkg/core/interop_system.go b/pkg/core/interop_system.go index a42dcde98..51daacde2 100644 --- a/pkg/core/interop_system.go +++ b/pkg/core/interop_system.go @@ -423,7 +423,7 @@ func contractCallEx(ic *interop.Context, v *vm.VM) error { return contractCallExInternal(ic, v, h, method, args, flags) } -func contractCallExInternal(ic *interop.Context, v *vm.VM, h []byte, method stackitem.Item, args stackitem.Item, _ smartcontract.CallFlag) error { +func contractCallExInternal(ic *interop.Context, v *vm.VM, h []byte, method stackitem.Item, args stackitem.Item, f smartcontract.CallFlag) error { u, err := util.Uint160DecodeBytesBE(h) if err != nil { return errors.New("invalid contract hash") @@ -442,7 +442,7 @@ func contractCallExInternal(ic *interop.Context, v *vm.VM, h []byte, method stac return errors.New("disallowed method call") } } - v.LoadScriptWithHash(cs.Script, u) + v.LoadScriptWithHash(cs.Script, u, f) v.Estack().PushVal(args) v.Estack().PushVal(method) return nil diff --git a/pkg/vm/context.go b/pkg/vm/context.go index 83e385d19..f984ce1e5 100644 --- a/pkg/vm/context.go +++ b/pkg/vm/context.go @@ -6,6 +6,7 @@ import ( "math/big" "github.com/nspcc-dev/neo-go/pkg/crypto/hash" + "github.com/nspcc-dev/neo-go/pkg/smartcontract" "github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/vm/opcode" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" @@ -39,6 +40,9 @@ type Context struct { // Script hash of the prog. scriptHash util.Uint160 + + // Call flags this context was created with. + callFlag smartcontract.CallFlag } var errNoInstParam = errors.New("failed to read instruction parameter") diff --git a/pkg/vm/interop.go b/pkg/vm/interop.go index 4024a09f4..90175b4c8 100644 --- a/pkg/vm/interop.go +++ b/pkg/vm/interop.go @@ -5,6 +5,7 @@ import ( "fmt" "sort" + "github.com/nspcc-dev/neo-go/pkg/smartcontract" "github.com/nspcc-dev/neo-go/pkg/vm/emit" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" ) @@ -14,8 +15,9 @@ type InteropFunc func(vm *VM) error // InteropFuncPrice represents an interop function with a price. type InteropFuncPrice struct { - Func InteropFunc - Price int + Func InteropFunc + Price int + RequiredFlags smartcontract.CallFlag } // interopIDFuncPrice adds an ID to the InteropFuncPrice. @@ -30,31 +32,31 @@ type InteropGetterFunc func(uint32) *InteropFuncPrice var defaultVMInterops = []interopIDFuncPrice{ {emit.InteropNameToID([]byte("System.Runtime.Log")), - InteropFuncPrice{runtimeLog, 1}}, + InteropFuncPrice{Func: runtimeLog, Price: 1}}, {emit.InteropNameToID([]byte("System.Runtime.Notify")), - InteropFuncPrice{runtimeNotify, 1}}, + InteropFuncPrice{Func: runtimeNotify, Price: 1}}, {emit.InteropNameToID([]byte("System.Runtime.Serialize")), - InteropFuncPrice{RuntimeSerialize, 1}}, + InteropFuncPrice{Func: RuntimeSerialize, Price: 1}}, {emit.InteropNameToID([]byte("System.Runtime.Deserialize")), - InteropFuncPrice{RuntimeDeserialize, 1}}, + InteropFuncPrice{Func: RuntimeDeserialize, Price: 1}}, {emit.InteropNameToID([]byte("System.Enumerator.Create")), - InteropFuncPrice{EnumeratorCreate, 1}}, + InteropFuncPrice{Func: EnumeratorCreate, Price: 1}}, {emit.InteropNameToID([]byte("System.Enumerator.Next")), - InteropFuncPrice{EnumeratorNext, 1}}, + InteropFuncPrice{Func: EnumeratorNext, Price: 1}}, {emit.InteropNameToID([]byte("System.Enumerator.Concat")), - InteropFuncPrice{EnumeratorConcat, 1}}, + InteropFuncPrice{Func: EnumeratorConcat, Price: 1}}, {emit.InteropNameToID([]byte("System.Enumerator.Value")), - InteropFuncPrice{EnumeratorValue, 1}}, + InteropFuncPrice{Func: EnumeratorValue, Price: 1}}, {emit.InteropNameToID([]byte("System.Iterator.Create")), - InteropFuncPrice{IteratorCreate, 1}}, + InteropFuncPrice{Func: IteratorCreate, Price: 1}}, {emit.InteropNameToID([]byte("System.Iterator.Concat")), - InteropFuncPrice{IteratorConcat, 1}}, + InteropFuncPrice{Func: IteratorConcat, Price: 1}}, {emit.InteropNameToID([]byte("System.Iterator.Key")), - InteropFuncPrice{IteratorKey, 1}}, + InteropFuncPrice{Func: IteratorKey, Price: 1}}, {emit.InteropNameToID([]byte("System.Iterator.Keys")), - InteropFuncPrice{IteratorKeys, 1}}, + InteropFuncPrice{Func: IteratorKeys, Price: 1}}, {emit.InteropNameToID([]byte("System.Iterator.Values")), - InteropFuncPrice{IteratorValues, 1}}, + InteropFuncPrice{Func: IteratorValues, Price: 1}}, } func getDefaultVMInterop(id uint32) *InteropFuncPrice { diff --git a/pkg/vm/json_test.go b/pkg/vm/json_test.go index ee2823d49..a758b24a0 100644 --- a/pkg/vm/json_test.go +++ b/pkg/vm/json_test.go @@ -17,6 +17,7 @@ import ( "testing" "github.com/nspcc-dev/neo-go/pkg/encoding/bigint" + "github.com/nspcc-dev/neo-go/pkg/smartcontract" "github.com/nspcc-dev/neo-go/pkg/vm/opcode" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" "github.com/stretchr/testify/require" @@ -111,11 +112,18 @@ func TestUT(t *testing.T) { } func getTestingInterop(id uint32) *InteropFuncPrice { - if id == binary.LittleEndian.Uint32([]byte{0x77, 0x77, 0x77, 0x77}) { - return &InteropFuncPrice{InteropFunc(func(v *VM) error { - v.estack.PushVal(stackitem.NewInterop(new(int))) - return nil - }), 0} + f := func(v *VM) error { + v.estack.PushVal(stackitem.NewInterop(new(int))) + return nil + } + switch id { + case binary.LittleEndian.Uint32([]byte{0x77, 0x77, 0x77, 0x77}): + return &InteropFuncPrice{Func: f} + case binary.LittleEndian.Uint32([]byte{0x66, 0x66, 0x66, 0x66}): + return &InteropFuncPrice{ + Func: f, + RequiredFlags: smartcontract.ReadOnly, + } } return nil } diff --git a/pkg/vm/vm.go b/pkg/vm/vm.go index 1163029f8..47a7a23d2 100644 --- a/pkg/vm/vm.go +++ b/pkg/vm/vm.go @@ -13,6 +13,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/crypto/keys" "github.com/nspcc-dev/neo-go/pkg/encoding/bigint" + "github.com/nspcc-dev/neo-go/pkg/smartcontract" "github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/vm/opcode" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" @@ -262,17 +263,23 @@ func (v *VM) Load(prog []byte) { // will immediately push a new context created from this script to // the invocation stack and starts executing it. func (v *VM) LoadScript(b []byte) { + v.LoadScriptWithFlags(b, smartcontract.NoneFlag) +} + +// LoadScriptWithFlags loads script and sets call flag to f. +func (v *VM) LoadScriptWithFlags(b []byte, f smartcontract.CallFlag) { ctx := NewContext(b) ctx.estack = v.estack ctx.astack = v.astack + ctx.callFlag = f v.istack.PushVal(ctx) } -// LoadScriptWithHash if similar to the LoadScript method, but it also loads +// LoadScriptWithHash if similar to the LoadScriptWithFlags method, but it also loads // given script hash directly into the Context to avoid its recalculations. It's // up to user of this function to make sure the script and hash match each other. -func (v *VM) LoadScriptWithHash(b []byte, hash util.Uint160) { - v.LoadScript(b) +func (v *VM) LoadScriptWithHash(b []byte, hash util.Uint160, f smartcontract.CallFlag) { + v.LoadScriptWithFlags(b, f) ctx := v.Context() ctx.scriptHash = hash } @@ -1253,6 +1260,9 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro case opcode.SYSCALL: interopID := GetInteropID(parameter) ifunc := v.GetInteropByID(interopID) + if !v.Context().callFlag.Has(ifunc.RequiredFlags) { + panic(fmt.Sprintf("missing call flags: %05b vs %05b", v.Context().callFlag, ifunc.RequiredFlags)) + } if ifunc == nil { panic(fmt.Sprintf("interop hook (%q/0x%x) not registered", parameter, interopID)) diff --git a/pkg/vm/vm_test.go b/pkg/vm/vm_test.go index ae9a06fd0..94911dc6c 100644 --- a/pkg/vm/vm_test.go +++ b/pkg/vm/vm_test.go @@ -12,6 +12,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/encoding/bigint" "github.com/nspcc-dev/neo-go/pkg/internal/random" "github.com/nspcc-dev/neo-go/pkg/io" + "github.com/nspcc-dev/neo-go/pkg/smartcontract" "github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/vm/emit" "github.com/nspcc-dev/neo-go/pkg/vm/opcode" @@ -22,10 +23,13 @@ import ( func fooInteropGetter(id uint32) *InteropFuncPrice { if id == emit.InteropNameToID([]byte("foo")) { - return &InteropFuncPrice{func(evm *VM) error { - evm.Estack().PushVal(1) - return nil - }, 1} + return &InteropFuncPrice{ + Func: func(evm *VM) error { + evm.Estack().PushVal(1) + return nil + }, + Price: 1, + } } return nil } @@ -812,6 +816,31 @@ func TestSerializeInterop(t *testing.T) { require.True(t, vm.HasFailed()) } +func getTestCallFlagsFunc(syscall []byte, flags smartcontract.CallFlag, result interface{}) func(t *testing.T) { + return func(t *testing.T) { + script := append([]byte{byte(opcode.SYSCALL)}, syscall...) + v := New() + v.RegisterInteropGetter(getTestingInterop) + v.LoadScriptWithFlags(script, flags) + if result == nil { + checkVMFailed(t, v) + return + } + runVM(t, v) + require.Equal(t, result, v.PopResult()) + } +} + +func TestCallFlags(t *testing.T) { + noFlags := []byte{0x77, 0x77, 0x77, 0x77} + readOnly := []byte{0x66, 0x66, 0x66, 0x66} + t.Run("NoFlagsNoRequired", getTestCallFlagsFunc(noFlags, smartcontract.NoneFlag, new(int))) + t.Run("ProvideFlagsNoRequired", getTestCallFlagsFunc(noFlags, smartcontract.AllowCall, new(int))) + t.Run("NoFlagsSomeRequired", getTestCallFlagsFunc(readOnly, smartcontract.NoneFlag, nil)) + t.Run("OnlyOneProvided", getTestCallFlagsFunc(readOnly, smartcontract.AllowCall, nil)) + t.Run("AllFlagsProvided", getTestCallFlagsFunc(readOnly, smartcontract.ReadOnly, new(int))) +} + func callNTimes(n uint16) []byte { return makeProgram( opcode.PUSHINT16, opcode.Opcode(n), opcode.Opcode(n>>8), // little-endian