vm: check calling flags on syscall invocation
This commit is contained in:
parent
55ab7535be
commit
bda94c74c3
7 changed files with 84 additions and 30 deletions
|
@ -14,6 +14,7 @@ import (
|
|||
"github.com/nspcc-dev/neo-go/pkg/core/storage"
|
||||
"github.com/nspcc-dev/neo-go/pkg/crypto/hash"
|
||||
"github.com/nspcc-dev/neo-go/pkg/encoding/address"
|
||||
"github.com/nspcc-dev/neo-go/pkg/smartcontract"
|
||||
"github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger"
|
||||
"github.com/nspcc-dev/neo-go/pkg/vm"
|
||||
"github.com/nspcc-dev/neo-go/pkg/vm/opcode"
|
||||
|
@ -64,7 +65,7 @@ func spawnVM(t *testing.T, ic *interop.Context, src string) *vm.VM {
|
|||
b, err := compiler.Compile(strings.NewReader(src))
|
||||
require.NoError(t, err)
|
||||
v := core.SpawnVM(ic)
|
||||
v.Load(b)
|
||||
v.LoadScriptWithFlags(b, smartcontract.All)
|
||||
return v
|
||||
}
|
||||
|
||||
|
|
|
@ -423,7 +423,7 @@ func contractCallEx(ic *interop.Context, v *vm.VM) error {
|
|||
return contractCallExInternal(ic, v, h, method, args, flags)
|
||||
}
|
||||
|
||||
func contractCallExInternal(ic *interop.Context, v *vm.VM, h []byte, method stackitem.Item, args stackitem.Item, _ smartcontract.CallFlag) error {
|
||||
func contractCallExInternal(ic *interop.Context, v *vm.VM, h []byte, method stackitem.Item, args stackitem.Item, f smartcontract.CallFlag) error {
|
||||
u, err := util.Uint160DecodeBytesBE(h)
|
||||
if err != nil {
|
||||
return errors.New("invalid contract hash")
|
||||
|
@ -442,7 +442,7 @@ func contractCallExInternal(ic *interop.Context, v *vm.VM, h []byte, method stac
|
|||
return errors.New("disallowed method call")
|
||||
}
|
||||
}
|
||||
v.LoadScriptWithHash(cs.Script, u)
|
||||
v.LoadScriptWithHash(cs.Script, u, f)
|
||||
v.Estack().PushVal(args)
|
||||
v.Estack().PushVal(method)
|
||||
return nil
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
"math/big"
|
||||
|
||||
"github.com/nspcc-dev/neo-go/pkg/crypto/hash"
|
||||
"github.com/nspcc-dev/neo-go/pkg/smartcontract"
|
||||
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||
"github.com/nspcc-dev/neo-go/pkg/vm/opcode"
|
||||
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
|
||||
|
@ -39,6 +40,9 @@ type Context struct {
|
|||
|
||||
// Script hash of the prog.
|
||||
scriptHash util.Uint160
|
||||
|
||||
// Call flags this context was created with.
|
||||
callFlag smartcontract.CallFlag
|
||||
}
|
||||
|
||||
var errNoInstParam = errors.New("failed to read instruction parameter")
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"fmt"
|
||||
"sort"
|
||||
|
||||
"github.com/nspcc-dev/neo-go/pkg/smartcontract"
|
||||
"github.com/nspcc-dev/neo-go/pkg/vm/emit"
|
||||
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
|
||||
)
|
||||
|
@ -14,8 +15,9 @@ type InteropFunc func(vm *VM) error
|
|||
|
||||
// InteropFuncPrice represents an interop function with a price.
|
||||
type InteropFuncPrice struct {
|
||||
Func InteropFunc
|
||||
Price int
|
||||
Func InteropFunc
|
||||
Price int
|
||||
RequiredFlags smartcontract.CallFlag
|
||||
}
|
||||
|
||||
// interopIDFuncPrice adds an ID to the InteropFuncPrice.
|
||||
|
@ -30,31 +32,31 @@ type InteropGetterFunc func(uint32) *InteropFuncPrice
|
|||
|
||||
var defaultVMInterops = []interopIDFuncPrice{
|
||||
{emit.InteropNameToID([]byte("System.Runtime.Log")),
|
||||
InteropFuncPrice{runtimeLog, 1}},
|
||||
InteropFuncPrice{Func: runtimeLog, Price: 1}},
|
||||
{emit.InteropNameToID([]byte("System.Runtime.Notify")),
|
||||
InteropFuncPrice{runtimeNotify, 1}},
|
||||
InteropFuncPrice{Func: runtimeNotify, Price: 1}},
|
||||
{emit.InteropNameToID([]byte("System.Runtime.Serialize")),
|
||||
InteropFuncPrice{RuntimeSerialize, 1}},
|
||||
InteropFuncPrice{Func: RuntimeSerialize, Price: 1}},
|
||||
{emit.InteropNameToID([]byte("System.Runtime.Deserialize")),
|
||||
InteropFuncPrice{RuntimeDeserialize, 1}},
|
||||
InteropFuncPrice{Func: RuntimeDeserialize, Price: 1}},
|
||||
{emit.InteropNameToID([]byte("System.Enumerator.Create")),
|
||||
InteropFuncPrice{EnumeratorCreate, 1}},
|
||||
InteropFuncPrice{Func: EnumeratorCreate, Price: 1}},
|
||||
{emit.InteropNameToID([]byte("System.Enumerator.Next")),
|
||||
InteropFuncPrice{EnumeratorNext, 1}},
|
||||
InteropFuncPrice{Func: EnumeratorNext, Price: 1}},
|
||||
{emit.InteropNameToID([]byte("System.Enumerator.Concat")),
|
||||
InteropFuncPrice{EnumeratorConcat, 1}},
|
||||
InteropFuncPrice{Func: EnumeratorConcat, Price: 1}},
|
||||
{emit.InteropNameToID([]byte("System.Enumerator.Value")),
|
||||
InteropFuncPrice{EnumeratorValue, 1}},
|
||||
InteropFuncPrice{Func: EnumeratorValue, Price: 1}},
|
||||
{emit.InteropNameToID([]byte("System.Iterator.Create")),
|
||||
InteropFuncPrice{IteratorCreate, 1}},
|
||||
InteropFuncPrice{Func: IteratorCreate, Price: 1}},
|
||||
{emit.InteropNameToID([]byte("System.Iterator.Concat")),
|
||||
InteropFuncPrice{IteratorConcat, 1}},
|
||||
InteropFuncPrice{Func: IteratorConcat, Price: 1}},
|
||||
{emit.InteropNameToID([]byte("System.Iterator.Key")),
|
||||
InteropFuncPrice{IteratorKey, 1}},
|
||||
InteropFuncPrice{Func: IteratorKey, Price: 1}},
|
||||
{emit.InteropNameToID([]byte("System.Iterator.Keys")),
|
||||
InteropFuncPrice{IteratorKeys, 1}},
|
||||
InteropFuncPrice{Func: IteratorKeys, Price: 1}},
|
||||
{emit.InteropNameToID([]byte("System.Iterator.Values")),
|
||||
InteropFuncPrice{IteratorValues, 1}},
|
||||
InteropFuncPrice{Func: IteratorValues, Price: 1}},
|
||||
}
|
||||
|
||||
func getDefaultVMInterop(id uint32) *InteropFuncPrice {
|
||||
|
|
|
@ -17,6 +17,7 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/nspcc-dev/neo-go/pkg/encoding/bigint"
|
||||
"github.com/nspcc-dev/neo-go/pkg/smartcontract"
|
||||
"github.com/nspcc-dev/neo-go/pkg/vm/opcode"
|
||||
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
@ -111,11 +112,18 @@ func TestUT(t *testing.T) {
|
|||
}
|
||||
|
||||
func getTestingInterop(id uint32) *InteropFuncPrice {
|
||||
if id == binary.LittleEndian.Uint32([]byte{0x77, 0x77, 0x77, 0x77}) {
|
||||
return &InteropFuncPrice{InteropFunc(func(v *VM) error {
|
||||
v.estack.PushVal(stackitem.NewInterop(new(int)))
|
||||
return nil
|
||||
}), 0}
|
||||
f := func(v *VM) error {
|
||||
v.estack.PushVal(stackitem.NewInterop(new(int)))
|
||||
return nil
|
||||
}
|
||||
switch id {
|
||||
case binary.LittleEndian.Uint32([]byte{0x77, 0x77, 0x77, 0x77}):
|
||||
return &InteropFuncPrice{Func: f}
|
||||
case binary.LittleEndian.Uint32([]byte{0x66, 0x66, 0x66, 0x66}):
|
||||
return &InteropFuncPrice{
|
||||
Func: f,
|
||||
RequiredFlags: smartcontract.ReadOnly,
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
16
pkg/vm/vm.go
16
pkg/vm/vm.go
|
@ -13,6 +13,7 @@ import (
|
|||
|
||||
"github.com/nspcc-dev/neo-go/pkg/crypto/keys"
|
||||
"github.com/nspcc-dev/neo-go/pkg/encoding/bigint"
|
||||
"github.com/nspcc-dev/neo-go/pkg/smartcontract"
|
||||
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||
"github.com/nspcc-dev/neo-go/pkg/vm/opcode"
|
||||
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
|
||||
|
@ -262,17 +263,23 @@ func (v *VM) Load(prog []byte) {
|
|||
// will immediately push a new context created from this script to
|
||||
// the invocation stack and starts executing it.
|
||||
func (v *VM) LoadScript(b []byte) {
|
||||
v.LoadScriptWithFlags(b, smartcontract.NoneFlag)
|
||||
}
|
||||
|
||||
// LoadScriptWithFlags loads script and sets call flag to f.
|
||||
func (v *VM) LoadScriptWithFlags(b []byte, f smartcontract.CallFlag) {
|
||||
ctx := NewContext(b)
|
||||
ctx.estack = v.estack
|
||||
ctx.astack = v.astack
|
||||
ctx.callFlag = f
|
||||
v.istack.PushVal(ctx)
|
||||
}
|
||||
|
||||
// LoadScriptWithHash if similar to the LoadScript method, but it also loads
|
||||
// LoadScriptWithHash if similar to the LoadScriptWithFlags method, but it also loads
|
||||
// given script hash directly into the Context to avoid its recalculations. It's
|
||||
// up to user of this function to make sure the script and hash match each other.
|
||||
func (v *VM) LoadScriptWithHash(b []byte, hash util.Uint160) {
|
||||
v.LoadScript(b)
|
||||
func (v *VM) LoadScriptWithHash(b []byte, hash util.Uint160, f smartcontract.CallFlag) {
|
||||
v.LoadScriptWithFlags(b, f)
|
||||
ctx := v.Context()
|
||||
ctx.scriptHash = hash
|
||||
}
|
||||
|
@ -1253,6 +1260,9 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro
|
|||
case opcode.SYSCALL:
|
||||
interopID := GetInteropID(parameter)
|
||||
ifunc := v.GetInteropByID(interopID)
|
||||
if !v.Context().callFlag.Has(ifunc.RequiredFlags) {
|
||||
panic(fmt.Sprintf("missing call flags: %05b vs %05b", v.Context().callFlag, ifunc.RequiredFlags))
|
||||
}
|
||||
|
||||
if ifunc == nil {
|
||||
panic(fmt.Sprintf("interop hook (%q/0x%x) not registered", parameter, interopID))
|
||||
|
|
|
@ -12,6 +12,7 @@ import (
|
|||
"github.com/nspcc-dev/neo-go/pkg/encoding/bigint"
|
||||
"github.com/nspcc-dev/neo-go/pkg/internal/random"
|
||||
"github.com/nspcc-dev/neo-go/pkg/io"
|
||||
"github.com/nspcc-dev/neo-go/pkg/smartcontract"
|
||||
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||
"github.com/nspcc-dev/neo-go/pkg/vm/emit"
|
||||
"github.com/nspcc-dev/neo-go/pkg/vm/opcode"
|
||||
|
@ -22,10 +23,13 @@ import (
|
|||
|
||||
func fooInteropGetter(id uint32) *InteropFuncPrice {
|
||||
if id == emit.InteropNameToID([]byte("foo")) {
|
||||
return &InteropFuncPrice{func(evm *VM) error {
|
||||
evm.Estack().PushVal(1)
|
||||
return nil
|
||||
}, 1}
|
||||
return &InteropFuncPrice{
|
||||
Func: func(evm *VM) error {
|
||||
evm.Estack().PushVal(1)
|
||||
return nil
|
||||
},
|
||||
Price: 1,
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -812,6 +816,31 @@ func TestSerializeInterop(t *testing.T) {
|
|||
require.True(t, vm.HasFailed())
|
||||
}
|
||||
|
||||
func getTestCallFlagsFunc(syscall []byte, flags smartcontract.CallFlag, result interface{}) func(t *testing.T) {
|
||||
return func(t *testing.T) {
|
||||
script := append([]byte{byte(opcode.SYSCALL)}, syscall...)
|
||||
v := New()
|
||||
v.RegisterInteropGetter(getTestingInterop)
|
||||
v.LoadScriptWithFlags(script, flags)
|
||||
if result == nil {
|
||||
checkVMFailed(t, v)
|
||||
return
|
||||
}
|
||||
runVM(t, v)
|
||||
require.Equal(t, result, v.PopResult())
|
||||
}
|
||||
}
|
||||
|
||||
func TestCallFlags(t *testing.T) {
|
||||
noFlags := []byte{0x77, 0x77, 0x77, 0x77}
|
||||
readOnly := []byte{0x66, 0x66, 0x66, 0x66}
|
||||
t.Run("NoFlagsNoRequired", getTestCallFlagsFunc(noFlags, smartcontract.NoneFlag, new(int)))
|
||||
t.Run("ProvideFlagsNoRequired", getTestCallFlagsFunc(noFlags, smartcontract.AllowCall, new(int)))
|
||||
t.Run("NoFlagsSomeRequired", getTestCallFlagsFunc(readOnly, smartcontract.NoneFlag, nil))
|
||||
t.Run("OnlyOneProvided", getTestCallFlagsFunc(readOnly, smartcontract.AllowCall, nil))
|
||||
t.Run("AllFlagsProvided", getTestCallFlagsFunc(readOnly, smartcontract.ReadOnly, new(int)))
|
||||
}
|
||||
|
||||
func callNTimes(n uint16) []byte {
|
||||
return makeProgram(
|
||||
opcode.PUSHINT16, opcode.Opcode(n), opcode.Opcode(n>>8), // little-endian
|
||||
|
|
Loading…
Reference in a new issue