core: introduce CheckReturnState constants

At the moment we should have 3 possible options to check return state
during vm context unloading:
	* no check
	* ensure the stack is empty
	* ensure the stack is not empty

It is necessary to distinguish them because new _deploy method shouldn't
left anything on stack. Example: if we use _deploy method before some
ordinary contract method which returns one value. Without these changes
the contract invocation will fail due to 2 elements on stack left after
invocation (the first `null` element is from _deploy, the second element
is return-value from the ordinary contract method).
This commit is contained in:
Anna Shaleva 2020-10-12 14:32:27 +03:00
parent 659fb89beb
commit fe1f0a7245
5 changed files with 29 additions and 9 deletions

View file

@ -10,6 +10,7 @@ import (
"github.com/nspcc-dev/neo-go/pkg/smartcontract" "github.com/nspcc-dev/neo-go/pkg/smartcontract"
"github.com/nspcc-dev/neo-go/pkg/smartcontract/manifest" "github.com/nspcc-dev/neo-go/pkg/smartcontract/manifest"
"github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/util"
"github.com/nspcc-dev/neo-go/pkg/vm"
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
) )
@ -51,12 +52,12 @@ func callExInternal(ic *interop.Context, h []byte, name string, args []stackitem
return errors.New("disallowed method call") return errors.New("disallowed method call")
} }
} }
return CallExInternal(ic, cs, name, args, f) return CallExInternal(ic, cs, name, args, f, vm.EnsureNotEmpty)
} }
// CallExInternal calls a contract with flags and can't be invoked directly by user. // CallExInternal calls a contract with flags and can't be invoked directly by user.
func CallExInternal(ic *interop.Context, cs *state.Contract, func CallExInternal(ic *interop.Context, cs *state.Contract,
name string, args []stackitem.Item, f smartcontract.CallFlag) error { name string, args []stackitem.Item, f smartcontract.CallFlag, checkReturn vm.CheckReturnState) error {
md := cs.Manifest.ABI.GetMethod(name) md := cs.Manifest.ABI.GetMethod(name)
if md == nil { if md == nil {
return fmt.Errorf("method '%s' not found", name) return fmt.Errorf("method '%s' not found", name)
@ -86,7 +87,7 @@ func CallExInternal(ic *interop.Context, cs *state.Contract,
// use Jump not Call here because context was loaded in LoadScript above. // use Jump not Call here because context was loaded in LoadScript above.
ic.VM.Jump(ic.VM.Context(), md.Offset) ic.VM.Jump(ic.VM.Context(), md.Offset)
} }
ic.VM.Context().CheckReturn = true ic.VM.Context().CheckReturn = checkReturn
md = cs.Manifest.ABI.GetMethod(manifest.MethodInit) md = cs.Manifest.ABI.GetMethod(manifest.MethodInit)
if md != nil { if md != nil {

View file

@ -203,7 +203,7 @@ func callDeploy(ic *interop.Context, cs *state.Contract, isUpdate bool) error {
md := cs.Manifest.ABI.GetMethod(manifest.MethodDeploy) md := cs.Manifest.ABI.GetMethod(manifest.MethodDeploy)
if md != nil { if md != nil {
return contract.CallExInternal(ic, cs, manifest.MethodDeploy, return contract.CallExInternal(ic, cs, manifest.MethodDeploy,
[]stackitem.Item{stackitem.NewBool(isUpdate)}, smartcontract.All) []stackitem.Item{stackitem.NewBool(isUpdate)}, smartcontract.All, vm.EnsureIsEmpty)
} }
return nil return nil
} }

View file

@ -912,7 +912,7 @@ func TestContractCreateDeploy(t *testing.T) {
require.NoError(t, ic.VM.Run()) require.NoError(t, ic.VM.Run())
v.LoadScriptWithFlags(currCs.Script, smartcontract.All) v.LoadScriptWithFlags(currCs.Script, smartcontract.All)
err := contract.CallExInternal(ic, cs, "getValue", nil, smartcontract.All) err := contract.CallExInternal(ic, cs, "getValue", nil, smartcontract.All, vm.EnsureNotEmpty)
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, v.Run()) require.NoError(t, v.Run())
require.Equal(t, "create", v.Estack().Pop().String()) require.Equal(t, "create", v.Estack().Pop().String())
@ -933,7 +933,7 @@ func TestContractCreateDeploy(t *testing.T) {
require.NoError(t, v.Run()) require.NoError(t, v.Run())
v.LoadScriptWithFlags(currCs.Script, smartcontract.All) v.LoadScriptWithFlags(currCs.Script, smartcontract.All)
err = contract.CallExInternal(ic, newCs, "getValue", nil, smartcontract.All) err = contract.CallExInternal(ic, newCs, "getValue", nil, smartcontract.All, vm.EnsureNotEmpty)
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, v.Run()) require.NoError(t, v.Run())
require.Equal(t, "update", v.Estack().Pop().String()) require.Equal(t, "update", v.Estack().Pop().String())

View file

@ -46,9 +46,22 @@ type Context struct {
callFlag smartcontract.CallFlag callFlag smartcontract.CallFlag
// CheckReturn specifies if amount of return values needs to be checked. // CheckReturn specifies if amount of return values needs to be checked.
CheckReturn bool CheckReturn CheckReturnState
} }
// CheckReturnState represents possible states of stack after opcode.RET was processed.
type CheckReturnState byte
const (
// NoCheck performs no return values check.
NoCheck CheckReturnState = 0
// EnsureIsEmpty checks that stack is empty and panics if not.
EnsureIsEmpty CheckReturnState = 1
// EnsureNotEmpty checks that stack contains not more than 1 element and panics if not.
// It pushes stackitem.Null on stack in case if there's no elements.
EnsureNotEmpty CheckReturnState = 2
)
var errNoInstParam = errors.New("failed to read instruction parameter") var errNoInstParam = errors.New("failed to read instruction parameter")
// NewContext returns a new Context object. // NewContext returns a new Context object.

View file

@ -1409,7 +1409,13 @@ func (v *VM) unloadContext(ctx *Context) {
if ctx.static != nil && currCtx != nil && ctx.static != currCtx.static { if ctx.static != nil && currCtx != nil && ctx.static != currCtx.static {
ctx.static.Clear() ctx.static.Clear()
} }
if ctx.CheckReturn { switch ctx.CheckReturn {
case NoCheck:
case EnsureIsEmpty:
if currCtx != nil && ctx.estack.len != 0 {
panic("return value amount is > 0")
}
case EnsureNotEmpty:
if currCtx != nil && ctx.estack.len == 0 { if currCtx != nil && ctx.estack.len == 0 {
currCtx.estack.PushVal(stackitem.Null{}) currCtx.estack.PushVal(stackitem.Null{})
} else if ctx.estack.len > 1 { } else if ctx.estack.len > 1 {
@ -1471,7 +1477,7 @@ func (v *VM) Call(ctx *Context, offset int) {
// package. // package.
func (v *VM) call(ctx *Context, offset int) { func (v *VM) call(ctx *Context, offset int) {
newCtx := ctx.Copy() newCtx := ctx.Copy()
newCtx.CheckReturn = false newCtx.CheckReturn = NoCheck
newCtx.local = nil newCtx.local = nil
newCtx.arguments = nil newCtx.arguments = nil
newCtx.tryStack = NewStack("exception") newCtx.tryStack = NewStack("exception")