From bf015994300b237f022ffa9132d3f4ec4fa2d90f Mon Sep 17 00:00:00 2001 From: Evgenii Stratonikov Date: Mon, 10 Aug 2020 11:52:32 +0300 Subject: [PATCH] vm: check return value on context unload When calling external contracts we expect exactly 1 value to be on stack. For methods returning nothing, `Null` value is pushed, otherwise it is an error.` --- pkg/core/interop_system.go | 1 + pkg/core/interop_system_test.go | 36 +++++++++++++++++++++++++++++++++ pkg/core/native/contract.go | 1 + pkg/vm/context.go | 3 +++ pkg/vm/vm.go | 8 ++++++++ pkg/vm/vm_test.go | 8 ++++++++ 6 files changed, 57 insertions(+) diff --git a/pkg/core/interop_system.go b/pkg/core/interop_system.go index abb396785..300f59556 100644 --- a/pkg/core/interop_system.go +++ b/pkg/core/interop_system.go @@ -519,6 +519,7 @@ func contractCallExInternal(ic *interop.Context, h []byte, name string, args []s } // use Jump not Call here because context was loaded in LoadScript above. ic.VM.Jump(ic.VM.Context(), md.Offset) + ic.VM.Context().CheckReturn = true } md = cs.Manifest.ABI.GetMethod(manifest.MethodInit) diff --git a/pkg/core/interop_system_test.go b/pkg/core/interop_system_test.go index 0a1776b45..29e1ddcbc 100644 --- a/pkg/core/interop_system_test.go +++ b/pkg/core/interop_system_test.go @@ -335,6 +335,8 @@ func getTestContractState() (*state.Contract, *state.Contract) { byte(opcode.DROP), byte(opcode.RET), byte(opcode.INITSSLOT), 1, byte(opcode.PUSH3), byte(opcode.STSFLD0), byte(opcode.RET), byte(opcode.LDSFLD0), byte(opcode.ADD), byte(opcode.RET), + byte(opcode.PUSH1), byte(opcode.PUSH2), byte(opcode.RET), + byte(opcode.RET), } h := hash.Hash160(script) m := manifest.NewManifest(h) @@ -372,6 +374,16 @@ func getTestContractState() (*state.Contract, *state.Contract) { }, ReturnType: smartcontract.IntegerType, }, + { + Name: "invalidReturn", + Offset: 15, + ReturnType: smartcontract.IntegerType, + }, + { + Name: "justReturn", + Offset: 18, + ReturnType: smartcontract.IntegerType, + }, } cs := &state.Contract{ Script: script, @@ -385,6 +397,8 @@ func getTestContractState() (*state.Contract, *state.Contract) { perm.Methods.Add("add") perm.Methods.Add("drop") perm.Methods.Add("add3") + perm.Methods.Add("invalidReturn") + perm.Methods.Add("justReturn") m.Permissions = append(m.Permissions, *perm) return cs, &state.Contract{ @@ -465,6 +479,28 @@ func TestContractCall(t *testing.T) { stackitem.NewArray([]stackitem.Item{stackitem.Make(1)}), "add", h.BytesBE())) }) + t.Run("ReturnValues", func(t *testing.T) { + t.Run("Many", func(t *testing.T) { + loadScript(ic, currScript, 42) + ic.VM.Estack().PushVal(stackitem.NewArray(nil)) + ic.VM.Estack().PushVal("invalidReturn") + ic.VM.Estack().PushVal(h.BytesBE()) + require.NoError(t, contractCall(ic)) + require.Error(t, ic.VM.Run()) + }) + t.Run("Void", func(t *testing.T) { + loadScript(ic, currScript, 42) + ic.VM.Estack().PushVal(stackitem.NewArray(nil)) + ic.VM.Estack().PushVal("justReturn") + ic.VM.Estack().PushVal(h.BytesBE()) + require.NoError(t, contractCall(ic)) + require.NoError(t, ic.VM.Run()) + require.Equal(t, 2, ic.VM.Estack().Len()) + require.Equal(t, stackitem.Null{}, ic.VM.Estack().Pop().Item()) + require.Equal(t, big.NewInt(42), ic.VM.Estack().Pop().Value()) + }) + }) + t.Run("IsolatedStack", func(t *testing.T) { loadScript(ic, currScript, 42) ic.VM.Estack().PushVal(stackitem.NewArray(nil)) diff --git a/pkg/core/native/contract.go b/pkg/core/native/contract.go index a77c9a963..517a4f4ac 100644 --- a/pkg/core/native/contract.go +++ b/pkg/core/native/contract.go @@ -61,6 +61,7 @@ func (cs *Contracts) GetPersistScript() []byte { emit.Opcode(w.BinWriter, opcode.NEWARRAY) emit.String(w.BinWriter, "onPersist") emit.AppCall(w.BinWriter, md.Hash) + emit.Opcode(w.BinWriter, opcode.DROP) } cs.persistScript = w.Bytes() return cs.persistScript diff --git a/pkg/vm/context.go b/pkg/vm/context.go index ca56b42e1..aff90f0a5 100644 --- a/pkg/vm/context.go +++ b/pkg/vm/context.go @@ -44,6 +44,9 @@ type Context struct { // Call flags this context was created with. callFlag smartcontract.CallFlag + + // CheckReturn specifies if amount of return values needs to be checked. + CheckReturn bool } var errNoInstParam = errors.New("failed to read instruction parameter") diff --git a/pkg/vm/vm.go b/pkg/vm/vm.go index edb229414..14a3a435e 100644 --- a/pkg/vm/vm.go +++ b/pkg/vm/vm.go @@ -1391,6 +1391,13 @@ func (v *VM) unloadContext(ctx *Context) { if ctx.static != nil && currCtx != nil && ctx.static != currCtx.static { ctx.static.Clear() } + if ctx.CheckReturn { + if currCtx != nil && ctx.estack.len == 0 { + currCtx.estack.PushVal(stackitem.Null{}) + } else if ctx.estack.len > 1 { + panic("return value amount is > 1") + } + } } // getTryParams splits TRY(L) instruction parameter into offsets for catch and finally blocks. @@ -1437,6 +1444,7 @@ func (v *VM) Jump(ctx *Context, offset int) { // pushes new context to the invocation state func (v *VM) Call(ctx *Context, offset int) { newCtx := ctx.Copy() + newCtx.CheckReturn = false newCtx.local = nil newCtx.arguments = nil v.istack.PushVal(newCtx) diff --git a/pkg/vm/vm_test.go b/pkg/vm/vm_test.go index 5cd5e0d18..cb55e949b 100644 --- a/pkg/vm/vm_test.go +++ b/pkg/vm/vm_test.go @@ -959,6 +959,14 @@ func TestCALLA(t *testing.T) { t.Run("Good", getTestFuncForVM(prog, 5, stackitem.NewPointer(4, prog))) } +func TestCALL(t *testing.T) { + prog := makeProgram( + opcode.CALL, 4, opcode.ADD, opcode.RET, + opcode.CALL, 3, opcode.RET, + opcode.PUSH1, opcode.PUSH2, opcode.RET) + runWithArgs(t, prog, 3) +} + func TestNOT(t *testing.T) { prog := makeProgram(opcode.NOT) t.Run("Bool", getTestFuncForVM(prog, true, false))