vm: check calling flags on syscall invocation
This commit is contained in:
parent
55ab7535be
commit
bda94c74c3
7 changed files with 84 additions and 30 deletions
|
@ -14,6 +14,7 @@ import (
|
||||||
"github.com/nspcc-dev/neo-go/pkg/core/storage"
|
"github.com/nspcc-dev/neo-go/pkg/core/storage"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/crypto/hash"
|
"github.com/nspcc-dev/neo-go/pkg/crypto/hash"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/encoding/address"
|
"github.com/nspcc-dev/neo-go/pkg/encoding/address"
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/smartcontract"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger"
|
"github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/vm"
|
"github.com/nspcc-dev/neo-go/pkg/vm"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/vm/opcode"
|
"github.com/nspcc-dev/neo-go/pkg/vm/opcode"
|
||||||
|
@ -64,7 +65,7 @@ func spawnVM(t *testing.T, ic *interop.Context, src string) *vm.VM {
|
||||||
b, err := compiler.Compile(strings.NewReader(src))
|
b, err := compiler.Compile(strings.NewReader(src))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
v := core.SpawnVM(ic)
|
v := core.SpawnVM(ic)
|
||||||
v.Load(b)
|
v.LoadScriptWithFlags(b, smartcontract.All)
|
||||||
return v
|
return v
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -423,7 +423,7 @@ func contractCallEx(ic *interop.Context, v *vm.VM) error {
|
||||||
return contractCallExInternal(ic, v, h, method, args, flags)
|
return contractCallExInternal(ic, v, h, method, args, flags)
|
||||||
}
|
}
|
||||||
|
|
||||||
func contractCallExInternal(ic *interop.Context, v *vm.VM, h []byte, method stackitem.Item, args stackitem.Item, _ smartcontract.CallFlag) error {
|
func contractCallExInternal(ic *interop.Context, v *vm.VM, h []byte, method stackitem.Item, args stackitem.Item, f smartcontract.CallFlag) error {
|
||||||
u, err := util.Uint160DecodeBytesBE(h)
|
u, err := util.Uint160DecodeBytesBE(h)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.New("invalid contract hash")
|
return errors.New("invalid contract hash")
|
||||||
|
@ -442,7 +442,7 @@ func contractCallExInternal(ic *interop.Context, v *vm.VM, h []byte, method stac
|
||||||
return errors.New("disallowed method call")
|
return errors.New("disallowed method call")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
v.LoadScriptWithHash(cs.Script, u)
|
v.LoadScriptWithHash(cs.Script, u, f)
|
||||||
v.Estack().PushVal(args)
|
v.Estack().PushVal(args)
|
||||||
v.Estack().PushVal(method)
|
v.Estack().PushVal(method)
|
||||||
return nil
|
return nil
|
||||||
|
|
|
@ -6,6 +6,7 @@ import (
|
||||||
"math/big"
|
"math/big"
|
||||||
|
|
||||||
"github.com/nspcc-dev/neo-go/pkg/crypto/hash"
|
"github.com/nspcc-dev/neo-go/pkg/crypto/hash"
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/smartcontract"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/util"
|
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/vm/opcode"
|
"github.com/nspcc-dev/neo-go/pkg/vm/opcode"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
|
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
|
||||||
|
@ -39,6 +40,9 @@ type Context struct {
|
||||||
|
|
||||||
// Script hash of the prog.
|
// Script hash of the prog.
|
||||||
scriptHash util.Uint160
|
scriptHash util.Uint160
|
||||||
|
|
||||||
|
// Call flags this context was created with.
|
||||||
|
callFlag smartcontract.CallFlag
|
||||||
}
|
}
|
||||||
|
|
||||||
var errNoInstParam = errors.New("failed to read instruction parameter")
|
var errNoInstParam = errors.New("failed to read instruction parameter")
|
||||||
|
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"sort"
|
"sort"
|
||||||
|
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/smartcontract"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/vm/emit"
|
"github.com/nspcc-dev/neo-go/pkg/vm/emit"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
|
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
|
||||||
)
|
)
|
||||||
|
@ -16,6 +17,7 @@ type InteropFunc func(vm *VM) error
|
||||||
type InteropFuncPrice struct {
|
type InteropFuncPrice struct {
|
||||||
Func InteropFunc
|
Func InteropFunc
|
||||||
Price int
|
Price int
|
||||||
|
RequiredFlags smartcontract.CallFlag
|
||||||
}
|
}
|
||||||
|
|
||||||
// interopIDFuncPrice adds an ID to the InteropFuncPrice.
|
// interopIDFuncPrice adds an ID to the InteropFuncPrice.
|
||||||
|
@ -30,31 +32,31 @@ type InteropGetterFunc func(uint32) *InteropFuncPrice
|
||||||
|
|
||||||
var defaultVMInterops = []interopIDFuncPrice{
|
var defaultVMInterops = []interopIDFuncPrice{
|
||||||
{emit.InteropNameToID([]byte("System.Runtime.Log")),
|
{emit.InteropNameToID([]byte("System.Runtime.Log")),
|
||||||
InteropFuncPrice{runtimeLog, 1}},
|
InteropFuncPrice{Func: runtimeLog, Price: 1}},
|
||||||
{emit.InteropNameToID([]byte("System.Runtime.Notify")),
|
{emit.InteropNameToID([]byte("System.Runtime.Notify")),
|
||||||
InteropFuncPrice{runtimeNotify, 1}},
|
InteropFuncPrice{Func: runtimeNotify, Price: 1}},
|
||||||
{emit.InteropNameToID([]byte("System.Runtime.Serialize")),
|
{emit.InteropNameToID([]byte("System.Runtime.Serialize")),
|
||||||
InteropFuncPrice{RuntimeSerialize, 1}},
|
InteropFuncPrice{Func: RuntimeSerialize, Price: 1}},
|
||||||
{emit.InteropNameToID([]byte("System.Runtime.Deserialize")),
|
{emit.InteropNameToID([]byte("System.Runtime.Deserialize")),
|
||||||
InteropFuncPrice{RuntimeDeserialize, 1}},
|
InteropFuncPrice{Func: RuntimeDeserialize, Price: 1}},
|
||||||
{emit.InteropNameToID([]byte("System.Enumerator.Create")),
|
{emit.InteropNameToID([]byte("System.Enumerator.Create")),
|
||||||
InteropFuncPrice{EnumeratorCreate, 1}},
|
InteropFuncPrice{Func: EnumeratorCreate, Price: 1}},
|
||||||
{emit.InteropNameToID([]byte("System.Enumerator.Next")),
|
{emit.InteropNameToID([]byte("System.Enumerator.Next")),
|
||||||
InteropFuncPrice{EnumeratorNext, 1}},
|
InteropFuncPrice{Func: EnumeratorNext, Price: 1}},
|
||||||
{emit.InteropNameToID([]byte("System.Enumerator.Concat")),
|
{emit.InteropNameToID([]byte("System.Enumerator.Concat")),
|
||||||
InteropFuncPrice{EnumeratorConcat, 1}},
|
InteropFuncPrice{Func: EnumeratorConcat, Price: 1}},
|
||||||
{emit.InteropNameToID([]byte("System.Enumerator.Value")),
|
{emit.InteropNameToID([]byte("System.Enumerator.Value")),
|
||||||
InteropFuncPrice{EnumeratorValue, 1}},
|
InteropFuncPrice{Func: EnumeratorValue, Price: 1}},
|
||||||
{emit.InteropNameToID([]byte("System.Iterator.Create")),
|
{emit.InteropNameToID([]byte("System.Iterator.Create")),
|
||||||
InteropFuncPrice{IteratorCreate, 1}},
|
InteropFuncPrice{Func: IteratorCreate, Price: 1}},
|
||||||
{emit.InteropNameToID([]byte("System.Iterator.Concat")),
|
{emit.InteropNameToID([]byte("System.Iterator.Concat")),
|
||||||
InteropFuncPrice{IteratorConcat, 1}},
|
InteropFuncPrice{Func: IteratorConcat, Price: 1}},
|
||||||
{emit.InteropNameToID([]byte("System.Iterator.Key")),
|
{emit.InteropNameToID([]byte("System.Iterator.Key")),
|
||||||
InteropFuncPrice{IteratorKey, 1}},
|
InteropFuncPrice{Func: IteratorKey, Price: 1}},
|
||||||
{emit.InteropNameToID([]byte("System.Iterator.Keys")),
|
{emit.InteropNameToID([]byte("System.Iterator.Keys")),
|
||||||
InteropFuncPrice{IteratorKeys, 1}},
|
InteropFuncPrice{Func: IteratorKeys, Price: 1}},
|
||||||
{emit.InteropNameToID([]byte("System.Iterator.Values")),
|
{emit.InteropNameToID([]byte("System.Iterator.Values")),
|
||||||
InteropFuncPrice{IteratorValues, 1}},
|
InteropFuncPrice{Func: IteratorValues, Price: 1}},
|
||||||
}
|
}
|
||||||
|
|
||||||
func getDefaultVMInterop(id uint32) *InteropFuncPrice {
|
func getDefaultVMInterop(id uint32) *InteropFuncPrice {
|
||||||
|
|
|
@ -17,6 +17,7 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/nspcc-dev/neo-go/pkg/encoding/bigint"
|
"github.com/nspcc-dev/neo-go/pkg/encoding/bigint"
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/smartcontract"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/vm/opcode"
|
"github.com/nspcc-dev/neo-go/pkg/vm/opcode"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
|
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
@ -111,11 +112,18 @@ func TestUT(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func getTestingInterop(id uint32) *InteropFuncPrice {
|
func getTestingInterop(id uint32) *InteropFuncPrice {
|
||||||
if id == binary.LittleEndian.Uint32([]byte{0x77, 0x77, 0x77, 0x77}) {
|
f := func(v *VM) error {
|
||||||
return &InteropFuncPrice{InteropFunc(func(v *VM) error {
|
|
||||||
v.estack.PushVal(stackitem.NewInterop(new(int)))
|
v.estack.PushVal(stackitem.NewInterop(new(int)))
|
||||||
return nil
|
return nil
|
||||||
}), 0}
|
}
|
||||||
|
switch id {
|
||||||
|
case binary.LittleEndian.Uint32([]byte{0x77, 0x77, 0x77, 0x77}):
|
||||||
|
return &InteropFuncPrice{Func: f}
|
||||||
|
case binary.LittleEndian.Uint32([]byte{0x66, 0x66, 0x66, 0x66}):
|
||||||
|
return &InteropFuncPrice{
|
||||||
|
Func: f,
|
||||||
|
RequiredFlags: smartcontract.ReadOnly,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
16
pkg/vm/vm.go
16
pkg/vm/vm.go
|
@ -13,6 +13,7 @@ import (
|
||||||
|
|
||||||
"github.com/nspcc-dev/neo-go/pkg/crypto/keys"
|
"github.com/nspcc-dev/neo-go/pkg/crypto/keys"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/encoding/bigint"
|
"github.com/nspcc-dev/neo-go/pkg/encoding/bigint"
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/smartcontract"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/util"
|
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/vm/opcode"
|
"github.com/nspcc-dev/neo-go/pkg/vm/opcode"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
|
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
|
||||||
|
@ -262,17 +263,23 @@ func (v *VM) Load(prog []byte) {
|
||||||
// will immediately push a new context created from this script to
|
// will immediately push a new context created from this script to
|
||||||
// the invocation stack and starts executing it.
|
// the invocation stack and starts executing it.
|
||||||
func (v *VM) LoadScript(b []byte) {
|
func (v *VM) LoadScript(b []byte) {
|
||||||
|
v.LoadScriptWithFlags(b, smartcontract.NoneFlag)
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadScriptWithFlags loads script and sets call flag to f.
|
||||||
|
func (v *VM) LoadScriptWithFlags(b []byte, f smartcontract.CallFlag) {
|
||||||
ctx := NewContext(b)
|
ctx := NewContext(b)
|
||||||
ctx.estack = v.estack
|
ctx.estack = v.estack
|
||||||
ctx.astack = v.astack
|
ctx.astack = v.astack
|
||||||
|
ctx.callFlag = f
|
||||||
v.istack.PushVal(ctx)
|
v.istack.PushVal(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoadScriptWithHash if similar to the LoadScript method, but it also loads
|
// LoadScriptWithHash if similar to the LoadScriptWithFlags method, but it also loads
|
||||||
// given script hash directly into the Context to avoid its recalculations. It's
|
// given script hash directly into the Context to avoid its recalculations. It's
|
||||||
// up to user of this function to make sure the script and hash match each other.
|
// up to user of this function to make sure the script and hash match each other.
|
||||||
func (v *VM) LoadScriptWithHash(b []byte, hash util.Uint160) {
|
func (v *VM) LoadScriptWithHash(b []byte, hash util.Uint160, f smartcontract.CallFlag) {
|
||||||
v.LoadScript(b)
|
v.LoadScriptWithFlags(b, f)
|
||||||
ctx := v.Context()
|
ctx := v.Context()
|
||||||
ctx.scriptHash = hash
|
ctx.scriptHash = hash
|
||||||
}
|
}
|
||||||
|
@ -1253,6 +1260,9 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro
|
||||||
case opcode.SYSCALL:
|
case opcode.SYSCALL:
|
||||||
interopID := GetInteropID(parameter)
|
interopID := GetInteropID(parameter)
|
||||||
ifunc := v.GetInteropByID(interopID)
|
ifunc := v.GetInteropByID(interopID)
|
||||||
|
if !v.Context().callFlag.Has(ifunc.RequiredFlags) {
|
||||||
|
panic(fmt.Sprintf("missing call flags: %05b vs %05b", v.Context().callFlag, ifunc.RequiredFlags))
|
||||||
|
}
|
||||||
|
|
||||||
if ifunc == nil {
|
if ifunc == nil {
|
||||||
panic(fmt.Sprintf("interop hook (%q/0x%x) not registered", parameter, interopID))
|
panic(fmt.Sprintf("interop hook (%q/0x%x) not registered", parameter, interopID))
|
||||||
|
|
|
@ -12,6 +12,7 @@ import (
|
||||||
"github.com/nspcc-dev/neo-go/pkg/encoding/bigint"
|
"github.com/nspcc-dev/neo-go/pkg/encoding/bigint"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/internal/random"
|
"github.com/nspcc-dev/neo-go/pkg/internal/random"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/io"
|
"github.com/nspcc-dev/neo-go/pkg/io"
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/smartcontract"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/util"
|
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/vm/emit"
|
"github.com/nspcc-dev/neo-go/pkg/vm/emit"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/vm/opcode"
|
"github.com/nspcc-dev/neo-go/pkg/vm/opcode"
|
||||||
|
@ -22,10 +23,13 @@ import (
|
||||||
|
|
||||||
func fooInteropGetter(id uint32) *InteropFuncPrice {
|
func fooInteropGetter(id uint32) *InteropFuncPrice {
|
||||||
if id == emit.InteropNameToID([]byte("foo")) {
|
if id == emit.InteropNameToID([]byte("foo")) {
|
||||||
return &InteropFuncPrice{func(evm *VM) error {
|
return &InteropFuncPrice{
|
||||||
|
Func: func(evm *VM) error {
|
||||||
evm.Estack().PushVal(1)
|
evm.Estack().PushVal(1)
|
||||||
return nil
|
return nil
|
||||||
}, 1}
|
},
|
||||||
|
Price: 1,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -812,6 +816,31 @@ func TestSerializeInterop(t *testing.T) {
|
||||||
require.True(t, vm.HasFailed())
|
require.True(t, vm.HasFailed())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getTestCallFlagsFunc(syscall []byte, flags smartcontract.CallFlag, result interface{}) func(t *testing.T) {
|
||||||
|
return func(t *testing.T) {
|
||||||
|
script := append([]byte{byte(opcode.SYSCALL)}, syscall...)
|
||||||
|
v := New()
|
||||||
|
v.RegisterInteropGetter(getTestingInterop)
|
||||||
|
v.LoadScriptWithFlags(script, flags)
|
||||||
|
if result == nil {
|
||||||
|
checkVMFailed(t, v)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
runVM(t, v)
|
||||||
|
require.Equal(t, result, v.PopResult())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCallFlags(t *testing.T) {
|
||||||
|
noFlags := []byte{0x77, 0x77, 0x77, 0x77}
|
||||||
|
readOnly := []byte{0x66, 0x66, 0x66, 0x66}
|
||||||
|
t.Run("NoFlagsNoRequired", getTestCallFlagsFunc(noFlags, smartcontract.NoneFlag, new(int)))
|
||||||
|
t.Run("ProvideFlagsNoRequired", getTestCallFlagsFunc(noFlags, smartcontract.AllowCall, new(int)))
|
||||||
|
t.Run("NoFlagsSomeRequired", getTestCallFlagsFunc(readOnly, smartcontract.NoneFlag, nil))
|
||||||
|
t.Run("OnlyOneProvided", getTestCallFlagsFunc(readOnly, smartcontract.AllowCall, nil))
|
||||||
|
t.Run("AllFlagsProvided", getTestCallFlagsFunc(readOnly, smartcontract.ReadOnly, new(int)))
|
||||||
|
}
|
||||||
|
|
||||||
func callNTimes(n uint16) []byte {
|
func callNTimes(n uint16) []byte {
|
||||||
return makeProgram(
|
return makeProgram(
|
||||||
opcode.PUSHINT16, opcode.Opcode(n), opcode.Opcode(n>>8), // little-endian
|
opcode.PUSHINT16, opcode.Opcode(n), opcode.Opcode(n>>8), // little-endian
|
||||||
|
|
Loading…
Reference in a new issue