vm: provide trigger upon VM creation
This commit is contained in:
parent
b12add5a78
commit
4dfce07d11
4 changed files with 49 additions and 3 deletions
|
@ -6,6 +6,7 @@ import (
|
||||||
"sort"
|
"sort"
|
||||||
|
|
||||||
"github.com/nspcc-dev/neo-go/pkg/smartcontract"
|
"github.com/nspcc-dev/neo-go/pkg/smartcontract"
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/vm/emit"
|
"github.com/nspcc-dev/neo-go/pkg/vm/emit"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
|
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
|
||||||
)
|
)
|
||||||
|
@ -17,6 +18,9 @@ type InteropFunc func(vm *VM) error
|
||||||
type InteropFuncPrice struct {
|
type InteropFuncPrice struct {
|
||||||
Func InteropFunc
|
Func InteropFunc
|
||||||
Price int
|
Price int
|
||||||
|
// AllowedTriggers is a mask representing triggers which should be allowed by an interop.
|
||||||
|
// 0 is interpreted as All.
|
||||||
|
AllowedTriggers trigger.Type
|
||||||
RequiredFlags smartcontract.CallFlag
|
RequiredFlags smartcontract.CallFlag
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -18,6 +18,7 @@ import (
|
||||||
|
|
||||||
"github.com/nspcc-dev/neo-go/pkg/encoding/bigint"
|
"github.com/nspcc-dev/neo-go/pkg/encoding/bigint"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/smartcontract"
|
"github.com/nspcc-dev/neo-go/pkg/smartcontract"
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/vm/opcode"
|
"github.com/nspcc-dev/neo-go/pkg/vm/opcode"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
|
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
@ -124,6 +125,11 @@ func getTestingInterop(id uint32) *InteropFuncPrice {
|
||||||
Func: f,
|
Func: f,
|
||||||
RequiredFlags: smartcontract.ReadOnly,
|
RequiredFlags: smartcontract.ReadOnly,
|
||||||
}
|
}
|
||||||
|
case binary.LittleEndian.Uint32([]byte{0x55, 0x55, 0x55, 0x55}):
|
||||||
|
return &InteropFuncPrice{
|
||||||
|
Func: f,
|
||||||
|
AllowedTriggers: trigger.Application,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
12
pkg/vm/vm.go
12
pkg/vm/vm.go
|
@ -14,6 +14,7 @@ import (
|
||||||
"github.com/nspcc-dev/neo-go/pkg/crypto/keys"
|
"github.com/nspcc-dev/neo-go/pkg/crypto/keys"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/encoding/bigint"
|
"github.com/nspcc-dev/neo-go/pkg/encoding/bigint"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/smartcontract"
|
"github.com/nspcc-dev/neo-go/pkg/smartcontract"
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/util"
|
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/vm/opcode"
|
"github.com/nspcc-dev/neo-go/pkg/vm/opcode"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
|
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
|
||||||
|
@ -85,18 +86,26 @@ type VM struct {
|
||||||
gasConsumed util.Fixed8
|
gasConsumed util.Fixed8
|
||||||
gasLimit util.Fixed8
|
gasLimit util.Fixed8
|
||||||
|
|
||||||
|
trigger trigger.Type
|
||||||
|
|
||||||
// Public keys cache.
|
// Public keys cache.
|
||||||
keys map[string]*keys.PublicKey
|
keys map[string]*keys.PublicKey
|
||||||
}
|
}
|
||||||
|
|
||||||
// New returns a new VM object ready to load .avm bytecode scripts.
|
// New returns a new VM object ready to load .avm bytecode scripts.
|
||||||
func New() *VM {
|
func New() *VM {
|
||||||
|
return NewWithTrigger(trigger.System)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewWithTrigger returns a new VM for executions triggered by t.
|
||||||
|
func NewWithTrigger(t trigger.Type) *VM {
|
||||||
vm := &VM{
|
vm := &VM{
|
||||||
getInterop: make([]InteropGetterFunc, 0, 3), // 3 functions is typical for our default usage.
|
getInterop: make([]InteropGetterFunc, 0, 3), // 3 functions is typical for our default usage.
|
||||||
state: haltState,
|
state: haltState,
|
||||||
istack: NewStack("invocation"),
|
istack: NewStack("invocation"),
|
||||||
refs: newRefCounter(),
|
refs: newRefCounter(),
|
||||||
keys: make(map[string]*keys.PublicKey),
|
keys: make(map[string]*keys.PublicKey),
|
||||||
|
trigger: t,
|
||||||
}
|
}
|
||||||
|
|
||||||
vm.estack = vm.newItemStack("evaluation")
|
vm.estack = vm.newItemStack("evaluation")
|
||||||
|
@ -1260,6 +1269,9 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro
|
||||||
case opcode.SYSCALL:
|
case opcode.SYSCALL:
|
||||||
interopID := GetInteropID(parameter)
|
interopID := GetInteropID(parameter)
|
||||||
ifunc := v.GetInteropByID(interopID)
|
ifunc := v.GetInteropByID(interopID)
|
||||||
|
if ifunc.AllowedTriggers != 0 && ifunc.AllowedTriggers&v.trigger == 0 {
|
||||||
|
panic(fmt.Sprintf("trigger not allowed: %s", v.trigger))
|
||||||
|
}
|
||||||
if !v.Context().callFlag.Has(ifunc.RequiredFlags) {
|
if !v.Context().callFlag.Has(ifunc.RequiredFlags) {
|
||||||
panic(fmt.Sprintf("missing call flags: %05b vs %05b", v.Context().callFlag, ifunc.RequiredFlags))
|
panic(fmt.Sprintf("missing call flags: %05b vs %05b", v.Context().callFlag, ifunc.RequiredFlags))
|
||||||
}
|
}
|
||||||
|
|
|
@ -13,6 +13,7 @@ import (
|
||||||
"github.com/nspcc-dev/neo-go/pkg/internal/random"
|
"github.com/nspcc-dev/neo-go/pkg/internal/random"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/io"
|
"github.com/nspcc-dev/neo-go/pkg/io"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/smartcontract"
|
"github.com/nspcc-dev/neo-go/pkg/smartcontract"
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/util"
|
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/vm/emit"
|
"github.com/nspcc-dev/neo-go/pkg/vm/emit"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/vm/opcode"
|
"github.com/nspcc-dev/neo-go/pkg/vm/opcode"
|
||||||
|
@ -841,6 +842,29 @@ func TestCallFlags(t *testing.T) {
|
||||||
t.Run("AllFlagsProvided", getTestCallFlagsFunc(readOnly, smartcontract.ReadOnly, new(int)))
|
t.Run("AllFlagsProvided", getTestCallFlagsFunc(readOnly, smartcontract.ReadOnly, new(int)))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getTestTriggerFunc(syscall []byte, tr trigger.Type, result interface{}) func(t *testing.T) {
|
||||||
|
return func(t *testing.T) {
|
||||||
|
script := append([]byte{byte(opcode.SYSCALL)}, syscall...)
|
||||||
|
v := NewWithTrigger(tr)
|
||||||
|
v.RegisterInteropGetter(getTestingInterop)
|
||||||
|
v.LoadScript(script)
|
||||||
|
if result == nil {
|
||||||
|
checkVMFailed(t, v)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
runVM(t, v)
|
||||||
|
require.Equal(t, result, v.PopResult())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAllowedTriggers(t *testing.T) {
|
||||||
|
noFlags := []byte{0x77, 0x77, 0x77, 0x77}
|
||||||
|
appOnly := []byte{0x55, 0x55, 0x55, 0x55}
|
||||||
|
t.Run("Application/NeedNothing", getTestTriggerFunc(noFlags, trigger.Application, new(int)))
|
||||||
|
t.Run("Application/NeedApplication", getTestTriggerFunc(appOnly, trigger.Application, new(int)))
|
||||||
|
t.Run("System/NeedApplication", getTestTriggerFunc(appOnly, trigger.System, nil))
|
||||||
|
}
|
||||||
|
|
||||||
func callNTimes(n uint16) []byte {
|
func callNTimes(n uint16) []byte {
|
||||||
return makeProgram(
|
return makeProgram(
|
||||||
opcode.PUSHINT16, opcode.Opcode(n), opcode.Opcode(n>>8), // little-endian
|
opcode.PUSHINT16, opcode.Opcode(n), opcode.Opcode(n>>8), // little-endian
|
||||||
|
|
Loading…
Reference in a new issue