diff --git a/pkg/vm/interop.go b/pkg/vm/interop.go index 90175b4c8..2d8fb848c 100644 --- a/pkg/vm/interop.go +++ b/pkg/vm/interop.go @@ -6,6 +6,7 @@ import ( "sort" "github.com/nspcc-dev/neo-go/pkg/smartcontract" + "github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger" "github.com/nspcc-dev/neo-go/pkg/vm/emit" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" ) @@ -15,9 +16,12 @@ type InteropFunc func(vm *VM) error // InteropFuncPrice represents an interop function with a price. type InteropFuncPrice struct { - Func InteropFunc - Price int - RequiredFlags smartcontract.CallFlag + Func InteropFunc + Price int + // AllowedTriggers is a mask representing triggers which should be allowed by an interop. + // 0 is interpreted as All. + AllowedTriggers trigger.Type + RequiredFlags smartcontract.CallFlag } // interopIDFuncPrice adds an ID to the InteropFuncPrice. diff --git a/pkg/vm/json_test.go b/pkg/vm/json_test.go index a758b24a0..614aec000 100644 --- a/pkg/vm/json_test.go +++ b/pkg/vm/json_test.go @@ -18,6 +18,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/encoding/bigint" "github.com/nspcc-dev/neo-go/pkg/smartcontract" + "github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger" "github.com/nspcc-dev/neo-go/pkg/vm/opcode" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" "github.com/stretchr/testify/require" @@ -124,6 +125,11 @@ func getTestingInterop(id uint32) *InteropFuncPrice { Func: f, RequiredFlags: smartcontract.ReadOnly, } + case binary.LittleEndian.Uint32([]byte{0x55, 0x55, 0x55, 0x55}): + return &InteropFuncPrice{ + Func: f, + AllowedTriggers: trigger.Application, + } } return nil } diff --git a/pkg/vm/vm.go b/pkg/vm/vm.go index 47a7a23d2..9cb456583 100644 --- a/pkg/vm/vm.go +++ b/pkg/vm/vm.go @@ -14,6 +14,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/crypto/keys" "github.com/nspcc-dev/neo-go/pkg/encoding/bigint" "github.com/nspcc-dev/neo-go/pkg/smartcontract" + "github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger" "github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/vm/opcode" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" @@ -85,18 +86,26 @@ type VM struct { gasConsumed util.Fixed8 gasLimit util.Fixed8 + trigger trigger.Type + // Public keys cache. keys map[string]*keys.PublicKey } // New returns a new VM object ready to load .avm bytecode scripts. func New() *VM { + return NewWithTrigger(trigger.System) +} + +// NewWithTrigger returns a new VM for executions triggered by t. +func NewWithTrigger(t trigger.Type) *VM { vm := &VM{ getInterop: make([]InteropGetterFunc, 0, 3), // 3 functions is typical for our default usage. state: haltState, istack: NewStack("invocation"), refs: newRefCounter(), keys: make(map[string]*keys.PublicKey), + trigger: t, } vm.estack = vm.newItemStack("evaluation") @@ -1260,6 +1269,9 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro case opcode.SYSCALL: interopID := GetInteropID(parameter) ifunc := v.GetInteropByID(interopID) + if ifunc.AllowedTriggers != 0 && ifunc.AllowedTriggers&v.trigger == 0 { + panic(fmt.Sprintf("trigger not allowed: %s", v.trigger)) + } if !v.Context().callFlag.Has(ifunc.RequiredFlags) { panic(fmt.Sprintf("missing call flags: %05b vs %05b", v.Context().callFlag, ifunc.RequiredFlags)) } diff --git a/pkg/vm/vm_test.go b/pkg/vm/vm_test.go index 94911dc6c..17ba7c72d 100644 --- a/pkg/vm/vm_test.go +++ b/pkg/vm/vm_test.go @@ -13,6 +13,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/internal/random" "github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/smartcontract" + "github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger" "github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/vm/emit" "github.com/nspcc-dev/neo-go/pkg/vm/opcode" @@ -841,6 +842,29 @@ func TestCallFlags(t *testing.T) { t.Run("AllFlagsProvided", getTestCallFlagsFunc(readOnly, smartcontract.ReadOnly, new(int))) } +func getTestTriggerFunc(syscall []byte, tr trigger.Type, result interface{}) func(t *testing.T) { + return func(t *testing.T) { + script := append([]byte{byte(opcode.SYSCALL)}, syscall...) + v := NewWithTrigger(tr) + v.RegisterInteropGetter(getTestingInterop) + v.LoadScript(script) + if result == nil { + checkVMFailed(t, v) + return + } + runVM(t, v) + require.Equal(t, result, v.PopResult()) + } +} + +func TestAllowedTriggers(t *testing.T) { + noFlags := []byte{0x77, 0x77, 0x77, 0x77} + appOnly := []byte{0x55, 0x55, 0x55, 0x55} + t.Run("Application/NeedNothing", getTestTriggerFunc(noFlags, trigger.Application, new(int))) + t.Run("Application/NeedApplication", getTestTriggerFunc(appOnly, trigger.Application, new(int))) + t.Run("System/NeedApplication", getTestTriggerFunc(appOnly, trigger.System, nil)) +} + func callNTimes(n uint16) []byte { return makeProgram( opcode.PUSHINT16, opcode.Opcode(n), opcode.Opcode(n>>8), // little-endian