vm: provide trigger upon VM creation

This commit is contained in:
Evgenii Stratonikov 2020-06-10 17:57:10 +03:00
parent b12add5a78
commit 4dfce07d11
4 changed files with 49 additions and 3 deletions

View file

@ -6,6 +6,7 @@ import (
"sort" "sort"
"github.com/nspcc-dev/neo-go/pkg/smartcontract" "github.com/nspcc-dev/neo-go/pkg/smartcontract"
"github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger"
"github.com/nspcc-dev/neo-go/pkg/vm/emit" "github.com/nspcc-dev/neo-go/pkg/vm/emit"
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
) )
@ -17,6 +18,9 @@ type InteropFunc func(vm *VM) error
type InteropFuncPrice struct { type InteropFuncPrice struct {
Func InteropFunc Func InteropFunc
Price int Price int
// AllowedTriggers is a mask representing triggers which should be allowed by an interop.
// 0 is interpreted as All.
AllowedTriggers trigger.Type
RequiredFlags smartcontract.CallFlag RequiredFlags smartcontract.CallFlag
} }

View file

@ -18,6 +18,7 @@ import (
"github.com/nspcc-dev/neo-go/pkg/encoding/bigint" "github.com/nspcc-dev/neo-go/pkg/encoding/bigint"
"github.com/nspcc-dev/neo-go/pkg/smartcontract" "github.com/nspcc-dev/neo-go/pkg/smartcontract"
"github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger"
"github.com/nspcc-dev/neo-go/pkg/vm/opcode" "github.com/nspcc-dev/neo-go/pkg/vm/opcode"
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -124,6 +125,11 @@ func getTestingInterop(id uint32) *InteropFuncPrice {
Func: f, Func: f,
RequiredFlags: smartcontract.ReadOnly, RequiredFlags: smartcontract.ReadOnly,
} }
case binary.LittleEndian.Uint32([]byte{0x55, 0x55, 0x55, 0x55}):
return &InteropFuncPrice{
Func: f,
AllowedTriggers: trigger.Application,
}
} }
return nil return nil
} }

View file

@ -14,6 +14,7 @@ import (
"github.com/nspcc-dev/neo-go/pkg/crypto/keys" "github.com/nspcc-dev/neo-go/pkg/crypto/keys"
"github.com/nspcc-dev/neo-go/pkg/encoding/bigint" "github.com/nspcc-dev/neo-go/pkg/encoding/bigint"
"github.com/nspcc-dev/neo-go/pkg/smartcontract" "github.com/nspcc-dev/neo-go/pkg/smartcontract"
"github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger"
"github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/util"
"github.com/nspcc-dev/neo-go/pkg/vm/opcode" "github.com/nspcc-dev/neo-go/pkg/vm/opcode"
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
@ -85,18 +86,26 @@ type VM struct {
gasConsumed util.Fixed8 gasConsumed util.Fixed8
gasLimit util.Fixed8 gasLimit util.Fixed8
trigger trigger.Type
// Public keys cache. // Public keys cache.
keys map[string]*keys.PublicKey keys map[string]*keys.PublicKey
} }
// New returns a new VM object ready to load .avm bytecode scripts. // New returns a new VM object ready to load .avm bytecode scripts.
func New() *VM { func New() *VM {
return NewWithTrigger(trigger.System)
}
// NewWithTrigger returns a new VM for executions triggered by t.
func NewWithTrigger(t trigger.Type) *VM {
vm := &VM{ vm := &VM{
getInterop: make([]InteropGetterFunc, 0, 3), // 3 functions is typical for our default usage. getInterop: make([]InteropGetterFunc, 0, 3), // 3 functions is typical for our default usage.
state: haltState, state: haltState,
istack: NewStack("invocation"), istack: NewStack("invocation"),
refs: newRefCounter(), refs: newRefCounter(),
keys: make(map[string]*keys.PublicKey), keys: make(map[string]*keys.PublicKey),
trigger: t,
} }
vm.estack = vm.newItemStack("evaluation") vm.estack = vm.newItemStack("evaluation")
@ -1260,6 +1269,9 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro
case opcode.SYSCALL: case opcode.SYSCALL:
interopID := GetInteropID(parameter) interopID := GetInteropID(parameter)
ifunc := v.GetInteropByID(interopID) ifunc := v.GetInteropByID(interopID)
if ifunc.AllowedTriggers != 0 && ifunc.AllowedTriggers&v.trigger == 0 {
panic(fmt.Sprintf("trigger not allowed: %s", v.trigger))
}
if !v.Context().callFlag.Has(ifunc.RequiredFlags) { if !v.Context().callFlag.Has(ifunc.RequiredFlags) {
panic(fmt.Sprintf("missing call flags: %05b vs %05b", v.Context().callFlag, ifunc.RequiredFlags)) panic(fmt.Sprintf("missing call flags: %05b vs %05b", v.Context().callFlag, ifunc.RequiredFlags))
} }

View file

@ -13,6 +13,7 @@ import (
"github.com/nspcc-dev/neo-go/pkg/internal/random" "github.com/nspcc-dev/neo-go/pkg/internal/random"
"github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/io"
"github.com/nspcc-dev/neo-go/pkg/smartcontract" "github.com/nspcc-dev/neo-go/pkg/smartcontract"
"github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger"
"github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/util"
"github.com/nspcc-dev/neo-go/pkg/vm/emit" "github.com/nspcc-dev/neo-go/pkg/vm/emit"
"github.com/nspcc-dev/neo-go/pkg/vm/opcode" "github.com/nspcc-dev/neo-go/pkg/vm/opcode"
@ -841,6 +842,29 @@ func TestCallFlags(t *testing.T) {
t.Run("AllFlagsProvided", getTestCallFlagsFunc(readOnly, smartcontract.ReadOnly, new(int))) t.Run("AllFlagsProvided", getTestCallFlagsFunc(readOnly, smartcontract.ReadOnly, new(int)))
} }
func getTestTriggerFunc(syscall []byte, tr trigger.Type, result interface{}) func(t *testing.T) {
return func(t *testing.T) {
script := append([]byte{byte(opcode.SYSCALL)}, syscall...)
v := NewWithTrigger(tr)
v.RegisterInteropGetter(getTestingInterop)
v.LoadScript(script)
if result == nil {
checkVMFailed(t, v)
return
}
runVM(t, v)
require.Equal(t, result, v.PopResult())
}
}
func TestAllowedTriggers(t *testing.T) {
noFlags := []byte{0x77, 0x77, 0x77, 0x77}
appOnly := []byte{0x55, 0x55, 0x55, 0x55}
t.Run("Application/NeedNothing", getTestTriggerFunc(noFlags, trigger.Application, new(int)))
t.Run("Application/NeedApplication", getTestTriggerFunc(appOnly, trigger.Application, new(int)))
t.Run("System/NeedApplication", getTestTriggerFunc(appOnly, trigger.System, nil))
}
func callNTimes(n uint16) []byte { func callNTimes(n uint16) []byte {
return makeProgram( return makeProgram(
opcode.PUSHINT16, opcode.Opcode(n), opcode.Opcode(n>>8), // little-endian opcode.PUSHINT16, opcode.Opcode(n), opcode.Opcode(n>>8), // little-endian