vm: check calling flags on syscall invocation

This commit is contained in:
Evgenii Stratonikov 2020-06-10 15:51:28 +03:00
parent 55ab7535be
commit bda94c74c3
7 changed files with 84 additions and 30 deletions

View file

@ -14,6 +14,7 @@ import (
"github.com/nspcc-dev/neo-go/pkg/core/storage" "github.com/nspcc-dev/neo-go/pkg/core/storage"
"github.com/nspcc-dev/neo-go/pkg/crypto/hash" "github.com/nspcc-dev/neo-go/pkg/crypto/hash"
"github.com/nspcc-dev/neo-go/pkg/encoding/address" "github.com/nspcc-dev/neo-go/pkg/encoding/address"
"github.com/nspcc-dev/neo-go/pkg/smartcontract"
"github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger" "github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger"
"github.com/nspcc-dev/neo-go/pkg/vm" "github.com/nspcc-dev/neo-go/pkg/vm"
"github.com/nspcc-dev/neo-go/pkg/vm/opcode" "github.com/nspcc-dev/neo-go/pkg/vm/opcode"
@ -64,7 +65,7 @@ func spawnVM(t *testing.T, ic *interop.Context, src string) *vm.VM {
b, err := compiler.Compile(strings.NewReader(src)) b, err := compiler.Compile(strings.NewReader(src))
require.NoError(t, err) require.NoError(t, err)
v := core.SpawnVM(ic) v := core.SpawnVM(ic)
v.Load(b) v.LoadScriptWithFlags(b, smartcontract.All)
return v return v
} }

View file

@ -423,7 +423,7 @@ func contractCallEx(ic *interop.Context, v *vm.VM) error {
return contractCallExInternal(ic, v, h, method, args, flags) return contractCallExInternal(ic, v, h, method, args, flags)
} }
func contractCallExInternal(ic *interop.Context, v *vm.VM, h []byte, method stackitem.Item, args stackitem.Item, _ smartcontract.CallFlag) error { func contractCallExInternal(ic *interop.Context, v *vm.VM, h []byte, method stackitem.Item, args stackitem.Item, f smartcontract.CallFlag) error {
u, err := util.Uint160DecodeBytesBE(h) u, err := util.Uint160DecodeBytesBE(h)
if err != nil { if err != nil {
return errors.New("invalid contract hash") return errors.New("invalid contract hash")
@ -442,7 +442,7 @@ func contractCallExInternal(ic *interop.Context, v *vm.VM, h []byte, method stac
return errors.New("disallowed method call") return errors.New("disallowed method call")
} }
} }
v.LoadScriptWithHash(cs.Script, u) v.LoadScriptWithHash(cs.Script, u, f)
v.Estack().PushVal(args) v.Estack().PushVal(args)
v.Estack().PushVal(method) v.Estack().PushVal(method)
return nil return nil

View file

@ -6,6 +6,7 @@ import (
"math/big" "math/big"
"github.com/nspcc-dev/neo-go/pkg/crypto/hash" "github.com/nspcc-dev/neo-go/pkg/crypto/hash"
"github.com/nspcc-dev/neo-go/pkg/smartcontract"
"github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/util"
"github.com/nspcc-dev/neo-go/pkg/vm/opcode" "github.com/nspcc-dev/neo-go/pkg/vm/opcode"
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
@ -39,6 +40,9 @@ type Context struct {
// Script hash of the prog. // Script hash of the prog.
scriptHash util.Uint160 scriptHash util.Uint160
// Call flags this context was created with.
callFlag smartcontract.CallFlag
} }
var errNoInstParam = errors.New("failed to read instruction parameter") var errNoInstParam = errors.New("failed to read instruction parameter")

View file

@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"sort" "sort"
"github.com/nspcc-dev/neo-go/pkg/smartcontract"
"github.com/nspcc-dev/neo-go/pkg/vm/emit" "github.com/nspcc-dev/neo-go/pkg/vm/emit"
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
) )
@ -16,6 +17,7 @@ type InteropFunc func(vm *VM) error
type InteropFuncPrice struct { type InteropFuncPrice struct {
Func InteropFunc Func InteropFunc
Price int Price int
RequiredFlags smartcontract.CallFlag
} }
// interopIDFuncPrice adds an ID to the InteropFuncPrice. // interopIDFuncPrice adds an ID to the InteropFuncPrice.
@ -30,31 +32,31 @@ type InteropGetterFunc func(uint32) *InteropFuncPrice
var defaultVMInterops = []interopIDFuncPrice{ var defaultVMInterops = []interopIDFuncPrice{
{emit.InteropNameToID([]byte("System.Runtime.Log")), {emit.InteropNameToID([]byte("System.Runtime.Log")),
InteropFuncPrice{runtimeLog, 1}}, InteropFuncPrice{Func: runtimeLog, Price: 1}},
{emit.InteropNameToID([]byte("System.Runtime.Notify")), {emit.InteropNameToID([]byte("System.Runtime.Notify")),
InteropFuncPrice{runtimeNotify, 1}}, InteropFuncPrice{Func: runtimeNotify, Price: 1}},
{emit.InteropNameToID([]byte("System.Runtime.Serialize")), {emit.InteropNameToID([]byte("System.Runtime.Serialize")),
InteropFuncPrice{RuntimeSerialize, 1}}, InteropFuncPrice{Func: RuntimeSerialize, Price: 1}},
{emit.InteropNameToID([]byte("System.Runtime.Deserialize")), {emit.InteropNameToID([]byte("System.Runtime.Deserialize")),
InteropFuncPrice{RuntimeDeserialize, 1}}, InteropFuncPrice{Func: RuntimeDeserialize, Price: 1}},
{emit.InteropNameToID([]byte("System.Enumerator.Create")), {emit.InteropNameToID([]byte("System.Enumerator.Create")),
InteropFuncPrice{EnumeratorCreate, 1}}, InteropFuncPrice{Func: EnumeratorCreate, Price: 1}},
{emit.InteropNameToID([]byte("System.Enumerator.Next")), {emit.InteropNameToID([]byte("System.Enumerator.Next")),
InteropFuncPrice{EnumeratorNext, 1}}, InteropFuncPrice{Func: EnumeratorNext, Price: 1}},
{emit.InteropNameToID([]byte("System.Enumerator.Concat")), {emit.InteropNameToID([]byte("System.Enumerator.Concat")),
InteropFuncPrice{EnumeratorConcat, 1}}, InteropFuncPrice{Func: EnumeratorConcat, Price: 1}},
{emit.InteropNameToID([]byte("System.Enumerator.Value")), {emit.InteropNameToID([]byte("System.Enumerator.Value")),
InteropFuncPrice{EnumeratorValue, 1}}, InteropFuncPrice{Func: EnumeratorValue, Price: 1}},
{emit.InteropNameToID([]byte("System.Iterator.Create")), {emit.InteropNameToID([]byte("System.Iterator.Create")),
InteropFuncPrice{IteratorCreate, 1}}, InteropFuncPrice{Func: IteratorCreate, Price: 1}},
{emit.InteropNameToID([]byte("System.Iterator.Concat")), {emit.InteropNameToID([]byte("System.Iterator.Concat")),
InteropFuncPrice{IteratorConcat, 1}}, InteropFuncPrice{Func: IteratorConcat, Price: 1}},
{emit.InteropNameToID([]byte("System.Iterator.Key")), {emit.InteropNameToID([]byte("System.Iterator.Key")),
InteropFuncPrice{IteratorKey, 1}}, InteropFuncPrice{Func: IteratorKey, Price: 1}},
{emit.InteropNameToID([]byte("System.Iterator.Keys")), {emit.InteropNameToID([]byte("System.Iterator.Keys")),
InteropFuncPrice{IteratorKeys, 1}}, InteropFuncPrice{Func: IteratorKeys, Price: 1}},
{emit.InteropNameToID([]byte("System.Iterator.Values")), {emit.InteropNameToID([]byte("System.Iterator.Values")),
InteropFuncPrice{IteratorValues, 1}}, InteropFuncPrice{Func: IteratorValues, Price: 1}},
} }
func getDefaultVMInterop(id uint32) *InteropFuncPrice { func getDefaultVMInterop(id uint32) *InteropFuncPrice {

View file

@ -17,6 +17,7 @@ import (
"testing" "testing"
"github.com/nspcc-dev/neo-go/pkg/encoding/bigint" "github.com/nspcc-dev/neo-go/pkg/encoding/bigint"
"github.com/nspcc-dev/neo-go/pkg/smartcontract"
"github.com/nspcc-dev/neo-go/pkg/vm/opcode" "github.com/nspcc-dev/neo-go/pkg/vm/opcode"
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -111,11 +112,18 @@ func TestUT(t *testing.T) {
} }
func getTestingInterop(id uint32) *InteropFuncPrice { func getTestingInterop(id uint32) *InteropFuncPrice {
if id == binary.LittleEndian.Uint32([]byte{0x77, 0x77, 0x77, 0x77}) { f := func(v *VM) error {
return &InteropFuncPrice{InteropFunc(func(v *VM) error {
v.estack.PushVal(stackitem.NewInterop(new(int))) v.estack.PushVal(stackitem.NewInterop(new(int)))
return nil return nil
}), 0} }
switch id {
case binary.LittleEndian.Uint32([]byte{0x77, 0x77, 0x77, 0x77}):
return &InteropFuncPrice{Func: f}
case binary.LittleEndian.Uint32([]byte{0x66, 0x66, 0x66, 0x66}):
return &InteropFuncPrice{
Func: f,
RequiredFlags: smartcontract.ReadOnly,
}
} }
return nil return nil
} }

View file

@ -13,6 +13,7 @@ import (
"github.com/nspcc-dev/neo-go/pkg/crypto/keys" "github.com/nspcc-dev/neo-go/pkg/crypto/keys"
"github.com/nspcc-dev/neo-go/pkg/encoding/bigint" "github.com/nspcc-dev/neo-go/pkg/encoding/bigint"
"github.com/nspcc-dev/neo-go/pkg/smartcontract"
"github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/util"
"github.com/nspcc-dev/neo-go/pkg/vm/opcode" "github.com/nspcc-dev/neo-go/pkg/vm/opcode"
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
@ -262,17 +263,23 @@ func (v *VM) Load(prog []byte) {
// will immediately push a new context created from this script to // will immediately push a new context created from this script to
// the invocation stack and starts executing it. // the invocation stack and starts executing it.
func (v *VM) LoadScript(b []byte) { func (v *VM) LoadScript(b []byte) {
v.LoadScriptWithFlags(b, smartcontract.NoneFlag)
}
// LoadScriptWithFlags loads script and sets call flag to f.
func (v *VM) LoadScriptWithFlags(b []byte, f smartcontract.CallFlag) {
ctx := NewContext(b) ctx := NewContext(b)
ctx.estack = v.estack ctx.estack = v.estack
ctx.astack = v.astack ctx.astack = v.astack
ctx.callFlag = f
v.istack.PushVal(ctx) v.istack.PushVal(ctx)
} }
// LoadScriptWithHash if similar to the LoadScript method, but it also loads // LoadScriptWithHash if similar to the LoadScriptWithFlags method, but it also loads
// given script hash directly into the Context to avoid its recalculations. It's // given script hash directly into the Context to avoid its recalculations. It's
// up to user of this function to make sure the script and hash match each other. // up to user of this function to make sure the script and hash match each other.
func (v *VM) LoadScriptWithHash(b []byte, hash util.Uint160) { func (v *VM) LoadScriptWithHash(b []byte, hash util.Uint160, f smartcontract.CallFlag) {
v.LoadScript(b) v.LoadScriptWithFlags(b, f)
ctx := v.Context() ctx := v.Context()
ctx.scriptHash = hash ctx.scriptHash = hash
} }
@ -1253,6 +1260,9 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro
case opcode.SYSCALL: case opcode.SYSCALL:
interopID := GetInteropID(parameter) interopID := GetInteropID(parameter)
ifunc := v.GetInteropByID(interopID) ifunc := v.GetInteropByID(interopID)
if !v.Context().callFlag.Has(ifunc.RequiredFlags) {
panic(fmt.Sprintf("missing call flags: %05b vs %05b", v.Context().callFlag, ifunc.RequiredFlags))
}
if ifunc == nil { if ifunc == nil {
panic(fmt.Sprintf("interop hook (%q/0x%x) not registered", parameter, interopID)) panic(fmt.Sprintf("interop hook (%q/0x%x) not registered", parameter, interopID))

View file

@ -12,6 +12,7 @@ import (
"github.com/nspcc-dev/neo-go/pkg/encoding/bigint" "github.com/nspcc-dev/neo-go/pkg/encoding/bigint"
"github.com/nspcc-dev/neo-go/pkg/internal/random" "github.com/nspcc-dev/neo-go/pkg/internal/random"
"github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/io"
"github.com/nspcc-dev/neo-go/pkg/smartcontract"
"github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/util"
"github.com/nspcc-dev/neo-go/pkg/vm/emit" "github.com/nspcc-dev/neo-go/pkg/vm/emit"
"github.com/nspcc-dev/neo-go/pkg/vm/opcode" "github.com/nspcc-dev/neo-go/pkg/vm/opcode"
@ -22,10 +23,13 @@ import (
func fooInteropGetter(id uint32) *InteropFuncPrice { func fooInteropGetter(id uint32) *InteropFuncPrice {
if id == emit.InteropNameToID([]byte("foo")) { if id == emit.InteropNameToID([]byte("foo")) {
return &InteropFuncPrice{func(evm *VM) error { return &InteropFuncPrice{
Func: func(evm *VM) error {
evm.Estack().PushVal(1) evm.Estack().PushVal(1)
return nil return nil
}, 1} },
Price: 1,
}
} }
return nil return nil
} }
@ -812,6 +816,31 @@ func TestSerializeInterop(t *testing.T) {
require.True(t, vm.HasFailed()) require.True(t, vm.HasFailed())
} }
func getTestCallFlagsFunc(syscall []byte, flags smartcontract.CallFlag, result interface{}) func(t *testing.T) {
return func(t *testing.T) {
script := append([]byte{byte(opcode.SYSCALL)}, syscall...)
v := New()
v.RegisterInteropGetter(getTestingInterop)
v.LoadScriptWithFlags(script, flags)
if result == nil {
checkVMFailed(t, v)
return
}
runVM(t, v)
require.Equal(t, result, v.PopResult())
}
}
func TestCallFlags(t *testing.T) {
noFlags := []byte{0x77, 0x77, 0x77, 0x77}
readOnly := []byte{0x66, 0x66, 0x66, 0x66}
t.Run("NoFlagsNoRequired", getTestCallFlagsFunc(noFlags, smartcontract.NoneFlag, new(int)))
t.Run("ProvideFlagsNoRequired", getTestCallFlagsFunc(noFlags, smartcontract.AllowCall, new(int)))
t.Run("NoFlagsSomeRequired", getTestCallFlagsFunc(readOnly, smartcontract.NoneFlag, nil))
t.Run("OnlyOneProvided", getTestCallFlagsFunc(readOnly, smartcontract.AllowCall, nil))
t.Run("AllFlagsProvided", getTestCallFlagsFunc(readOnly, smartcontract.ReadOnly, new(int)))
}
func callNTimes(n uint16) []byte { func callNTimes(n uint16) []byte {
return makeProgram( return makeProgram(
opcode.PUSHINT16, opcode.Opcode(n), opcode.Opcode(n>>8), // little-endian opcode.PUSHINT16, opcode.Opcode(n), opcode.Opcode(n>>8), // little-endian