From 73f888f02e02badd78fe462eebbf022340255cc7 Mon Sep 17 00:00:00 2001 From: Evgeniy Stratonikov Date: Tue, 26 Jan 2021 17:37:34 +0300 Subject: [PATCH 1/2] core: allow to overload contract methods Multiple methods with different parameter count can co-exist. --- cli/wallet/wallet.go | 2 +- pkg/core/blockchain.go | 4 +- pkg/core/interop/contract/call.go | 8 +-- pkg/core/interop_system_test.go | 51 +++++++++++++++---- pkg/core/native/management.go | 2 +- pkg/core/native_management_test.go | 3 +- pkg/smartcontract/manifest/manifest.go | 4 +- pkg/smartcontract/manifest/standard/comply.go | 7 +-- .../manifest/standard/comply_test.go | 12 ++--- pkg/vm/cli/cli.go | 14 ++--- 10 files changed, 67 insertions(+), 40 deletions(-) diff --git a/cli/wallet/wallet.go b/cli/wallet/wallet.go index f8dc0f44c..911e9afa9 100644 --- a/cli/wallet/wallet.go +++ b/cli/wallet/wallet.go @@ -419,7 +419,7 @@ func importDeployed(ctx *cli.Context) error { if err != nil { return cli.NewExitError(fmt.Errorf("can't fetch contract info: %w", err), 1) } - md := cs.Manifest.ABI.GetMethod(manifest.MethodVerify) + md := cs.Manifest.ABI.GetMethod(manifest.MethodVerify, -1) if md == nil { return cli.NewExitError("contract has no `verify` method", 1) } diff --git a/pkg/core/blockchain.go b/pkg/core/blockchain.go index f143259e8..94052597c 100644 --- a/pkg/core/blockchain.go +++ b/pkg/core/blockchain.go @@ -1667,11 +1667,11 @@ func (bc *Blockchain) initVerificationVM(ic *interop.Context, hash util.Uint160, if err != nil { return ErrUnknownVerificationContract } - md := cs.Manifest.ABI.GetMethod(manifest.MethodVerify) + md := cs.Manifest.ABI.GetMethod(manifest.MethodVerify, -1) if md == nil { return ErrInvalidVerificationContract } - initMD := cs.Manifest.ABI.GetMethod(manifest.MethodInit) + initMD := cs.Manifest.ABI.GetMethod(manifest.MethodInit, 0) v.LoadScriptWithHash(cs.NEF.Script, hash, callflag.ReadStates) v.Context().NEF = &cs.NEF v.Jump(v.Context(), md.Offset) diff --git a/pkg/core/interop/contract/call.go b/pkg/core/interop/contract/call.go index 6603f337d..b669241e6 100644 --- a/pkg/core/interop/contract/call.go +++ b/pkg/core/interop/contract/call.go @@ -55,7 +55,7 @@ func Call(ic *interop.Context) error { if strings.HasPrefix(method, "_") { return errors.New("invalid method name (starts with '_')") } - md := cs.Manifest.ABI.GetMethod(method) + md := cs.Manifest.ABI.GetMethod(method, len(args)) if md == nil { return errors.New("method not found") } @@ -68,7 +68,7 @@ func Call(ic *interop.Context) error { func callInternal(ic *interop.Context, cs *state.Contract, name string, f callflag.CallFlag, hasReturn bool, args []stackitem.Item) error { - md := cs.Manifest.ABI.GetMethod(name) + md := cs.Manifest.ABI.GetMethod(name, len(args)) if md.Safe { f &^= callflag.WriteStates } else if ctx := ic.VM.Context(); ctx != nil && ctx.IsDeployed() { @@ -85,7 +85,7 @@ func callInternal(ic *interop.Context, cs *state.Contract, name string, f callfl // callExFromNative calls a contract with flags using provided calling hash. func callExFromNative(ic *interop.Context, caller util.Uint160, cs *state.Contract, name string, args []stackitem.Item, f callflag.CallFlag, hasReturn bool) error { - md := cs.Manifest.ABI.GetMethod(name) + md := cs.Manifest.ABI.GetMethod(name, len(args)) if md == nil { return fmt.Errorf("method '%s' not found", name) } @@ -119,7 +119,7 @@ func callExFromNative(ic *interop.Context, caller util.Uint160, cs *state.Contra ic.VM.Context().RetCount = 0 } - md = cs.Manifest.ABI.GetMethod(manifest.MethodInit) + md = cs.Manifest.ABI.GetMethod(manifest.MethodInit, 0) if md != nil { ic.VM.Call(ic.VM.Context(), md.Offset) } diff --git a/pkg/core/interop_system_test.go b/pkg/core/interop_system_test.go index b30e80a3f..88ce4fee1 100644 --- a/pkg/core/interop_system_test.go +++ b/pkg/core/interop_system_test.go @@ -461,6 +461,8 @@ func getTestContractState(bc *Blockchain) (*state.Contract, *state.Contract) { emit.Opcodes(w.BinWriter, opcode.ABORT) addOff := w.Len() emit.Opcodes(w.BinWriter, opcode.ADD, opcode.RET) + addMultiOff := w.Len() + emit.Opcodes(w.BinWriter, opcode.ADD, opcode.ADD, opcode.RET) ret7Off := w.Len() emit.Opcodes(w.BinWriter, opcode.PUSH7, opcode.RET) dropOff := w.Len() @@ -533,6 +535,16 @@ func getTestContractState(bc *Blockchain) (*state.Contract, *state.Contract) { }, ReturnType: smartcontract.IntegerType, }, + { + Name: "add", + Offset: addMultiOff, + Parameters: []manifest.Parameter{ + manifest.NewParameter("addend1", smartcontract.IntegerType), + manifest.NewParameter("addend2", smartcontract.IntegerType), + manifest.NewParameter("addend3", smartcontract.IntegerType), + }, + ReturnType: smartcontract.IntegerType, + }, { Name: "ret7", Offset: ret7Off, @@ -731,16 +743,31 @@ func TestContractCall(t *testing.T) { addArgs := stackitem.NewArray([]stackitem.Item{stackitem.Make(1), stackitem.Make(2)}) t.Run("Good", func(t *testing.T) { - loadScript(ic, currScript, 42) - ic.VM.Estack().PushVal(addArgs) - ic.VM.Estack().PushVal(callflag.All) - ic.VM.Estack().PushVal("add") - ic.VM.Estack().PushVal(h.BytesBE()) - require.NoError(t, contract.Call(ic)) - require.NoError(t, ic.VM.Run()) - require.Equal(t, 2, ic.VM.Estack().Len()) - require.Equal(t, big.NewInt(3), ic.VM.Estack().Pop().Value()) - require.Equal(t, big.NewInt(42), ic.VM.Estack().Pop().Value()) + t.Run("2 arguments", func(t *testing.T) { + loadScript(ic, currScript, 42) + ic.VM.Estack().PushVal(addArgs) + ic.VM.Estack().PushVal(callflag.All) + ic.VM.Estack().PushVal("add") + ic.VM.Estack().PushVal(h.BytesBE()) + require.NoError(t, contract.Call(ic)) + require.NoError(t, ic.VM.Run()) + require.Equal(t, 2, ic.VM.Estack().Len()) + require.Equal(t, big.NewInt(3), ic.VM.Estack().Pop().Value()) + require.Equal(t, big.NewInt(42), ic.VM.Estack().Pop().Value()) + }) + t.Run("3 arguments", func(t *testing.T) { + loadScript(ic, currScript, 42) + ic.VM.Estack().PushVal(stackitem.NewArray( + append(addArgs.Value().([]stackitem.Item), stackitem.Make(3)))) + ic.VM.Estack().PushVal(callflag.All) + ic.VM.Estack().PushVal("add") + ic.VM.Estack().PushVal(h.BytesBE()) + require.NoError(t, contract.Call(ic)) + require.NoError(t, ic.VM.Run()) + require.Equal(t, 2, ic.VM.Estack().Len()) + require.Equal(t, big.NewInt(6), ic.VM.Estack().Pop().Value()) + require.Equal(t, big.NewInt(42), ic.VM.Estack().Pop().Value()) + }) }) t.Run("CallExInvalidFlag", func(t *testing.T) { @@ -778,6 +805,10 @@ func TestContractCall(t *testing.T) { t.Run("Arguments", runInvalid(1, "add", h.BytesBE())) t.Run("NotEnoughArguments", runInvalid( stackitem.NewArray([]stackitem.Item{stackitem.Make(1)}), "add", h.BytesBE())) + t.Run("TooMuchArguments", runInvalid( + stackitem.NewArray([]stackitem.Item{ + stackitem.Make(1), stackitem.Make(2), stackitem.Make(3), stackitem.Make(4)}), + "add", h.BytesBE())) }) t.Run("ReturnValues", func(t *testing.T) { diff --git a/pkg/core/native/management.go b/pkg/core/native/management.go index f25843c81..7cc9b3f21 100644 --- a/pkg/core/native/management.go +++ b/pkg/core/native/management.go @@ -379,7 +379,7 @@ func (m *Management) setMinimumDeploymentFee(ic *interop.Context, args []stackit } func (m *Management) callDeploy(ic *interop.Context, cs *state.Contract, isUpdate bool) { - md := cs.Manifest.ABI.GetMethod(manifest.MethodDeploy) + md := cs.Manifest.ABI.GetMethod(manifest.MethodDeploy, 1) if md != nil { err := contract.CallFromNative(ic, m.Hash, cs, manifest.MethodDeploy, []stackitem.Item{stackitem.NewBool(isUpdate)}, false) diff --git a/pkg/core/native_management_test.go b/pkg/core/native_management_test.go index eb94d5bf9..2b8dd5ec0 100644 --- a/pkg/core/native_management_test.go +++ b/pkg/core/native_management_test.go @@ -227,9 +227,8 @@ func TestContractDeploy(t *testing.T) { Offset: 0, Parameters: []manifest.Parameter{ manifest.NewParameter("isUpdate", smartcontract.BoolType), - manifest.NewParameter("param", smartcontract.IntegerType), }, - ReturnType: smartcontract.VoidType, + ReturnType: smartcontract.ArrayType, }, } nefD, err := nef.NewFile(deployScript) diff --git a/pkg/smartcontract/manifest/manifest.go b/pkg/smartcontract/manifest/manifest.go index c095fa4a8..e65175e85 100644 --- a/pkg/smartcontract/manifest/manifest.go +++ b/pkg/smartcontract/manifest/manifest.go @@ -76,9 +76,9 @@ func DefaultManifest(name string) *Manifest { } // GetMethod returns methods with the specified name. -func (a *ABI) GetMethod(name string) *Method { +func (a *ABI) GetMethod(name string, paramCount int) *Method { for i := range a.Methods { - if a.Methods[i].Name == name { + if a.Methods[i].Name == name && (paramCount == -1 || len(a.Methods[i].Parameters) == paramCount) { return &a.Methods[i] } } diff --git a/pkg/smartcontract/manifest/standard/comply.go b/pkg/smartcontract/manifest/standard/comply.go index b629723bc..1d03c55f1 100644 --- a/pkg/smartcontract/manifest/standard/comply.go +++ b/pkg/smartcontract/manifest/standard/comply.go @@ -40,15 +40,12 @@ func Check(m *manifest.Manifest, standards ...string) error { func Comply(m, st *manifest.Manifest) error { for _, stm := range st.ABI.Methods { name := stm.Name - md := m.ABI.GetMethod(name) + md := m.ABI.GetMethod(name, len(stm.Parameters)) if md == nil { - return fmt.Errorf("%w: '%s'", ErrMethodMissing, name) + return fmt.Errorf("%w: '%s' with %d parameters", ErrMethodMissing, name, len(stm.Parameters)) } else if stm.ReturnType != md.ReturnType { return fmt.Errorf("%w: '%s' (expected %s, got %s)", ErrInvalidReturnType, name, stm.ReturnType, md.ReturnType) - } else if len(stm.Parameters) != len(md.Parameters) { - return fmt.Errorf("%w: '%s' (expected %d, got %d)", ErrInvalidParameterCount, - name, len(stm.Parameters), len(md.Parameters)) } for i := range stm.Parameters { if stm.Parameters[i].Type != md.Parameters[i].Type { diff --git a/pkg/smartcontract/manifest/standard/comply_test.go b/pkg/smartcontract/manifest/standard/comply_test.go index 8a1851d42..da05947c5 100644 --- a/pkg/smartcontract/manifest/standard/comply_test.go +++ b/pkg/smartcontract/manifest/standard/comply_test.go @@ -37,14 +37,14 @@ func fooMethodBarEvent() *manifest.Manifest { func TestComplyMissingMethod(t *testing.T) { m := fooMethodBarEvent() - m.ABI.GetMethod("foo").Name = "notafoo" + m.ABI.GetMethod("foo", -1).Name = "notafoo" err := Comply(m, fooMethodBarEvent()) require.True(t, errors.Is(err, ErrMethodMissing)) } func TestComplyInvalidReturnType(t *testing.T) { m := fooMethodBarEvent() - m.ABI.GetMethod("foo").ReturnType = smartcontract.VoidType + m.ABI.GetMethod("foo", -1).ReturnType = smartcontract.VoidType err := Comply(m, fooMethodBarEvent()) require.True(t, errors.Is(err, ErrInvalidReturnType)) } @@ -52,10 +52,10 @@ func TestComplyInvalidReturnType(t *testing.T) { func TestComplyMethodParameterCount(t *testing.T) { t.Run("Method", func(t *testing.T) { m := fooMethodBarEvent() - f := m.ABI.GetMethod("foo") + f := m.ABI.GetMethod("foo", -1) f.Parameters = append(f.Parameters, manifest.Parameter{Type: smartcontract.BoolType}) err := Comply(m, fooMethodBarEvent()) - require.True(t, errors.Is(err, ErrInvalidParameterCount)) + require.True(t, errors.Is(err, ErrMethodMissing)) }) t.Run("Event", func(t *testing.T) { m := fooMethodBarEvent() @@ -69,7 +69,7 @@ func TestComplyMethodParameterCount(t *testing.T) { func TestComplyParameterType(t *testing.T) { t.Run("Method", func(t *testing.T) { m := fooMethodBarEvent() - m.ABI.GetMethod("foo").Parameters[0].Type = smartcontract.InteropInterfaceType + m.ABI.GetMethod("foo", -1).Parameters[0].Type = smartcontract.InteropInterfaceType err := Comply(m, fooMethodBarEvent()) require.True(t, errors.Is(err, ErrInvalidParameterType)) }) @@ -90,7 +90,7 @@ func TestMissingEvent(t *testing.T) { func TestSafeFlag(t *testing.T) { m := fooMethodBarEvent() - m.ABI.GetMethod("foo").Safe = false + m.ABI.GetMethod("foo", -1).Safe = false err := Comply(m, fooMethodBarEvent()) require.True(t, errors.Is(err, ErrSafeMethodMismatch)) } diff --git a/pkg/vm/cli/cli.go b/pkg/vm/cli/cli.go index 4582482bf..3e3495dd1 100644 --- a/pkg/vm/cli/cli.go +++ b/pkg/vm/cli/cli.go @@ -389,25 +389,25 @@ func handleRun(c *ishell.Context) { runCurrent = c.Args[0] != "_" ) + params, err = parseArgs(c.Args[1:]) + if err != nil { + c.Err(err) + return + } if runCurrent { - md := m.ABI.GetMethod(c.Args[0]) + md := m.ABI.GetMethod(c.Args[0], len(params)) if md == nil { c.Err(fmt.Errorf("%w: method not found", ErrInvalidParameter)) return } offset = md.Offset } - params, err = parseArgs(c.Args[1:]) - if err != nil { - c.Err(err) - return - } for i := len(params) - 1; i >= 0; i-- { v.Estack().PushVal(params[i]) } if runCurrent { v.Jump(v.Context(), offset) - if initMD := m.ABI.GetMethod(manifest.MethodInit); initMD != nil { + if initMD := m.ABI.GetMethod(manifest.MethodInit, 0); initMD != nil { v.Call(v.Context(), initMD.Offset) } } From dd1e2cefe4635d0a2dec7e1d999c57d46108aa0f Mon Sep 17 00:00:00 2001 From: Evgeniy Stratonikov Date: Tue, 26 Jan 2021 18:00:08 +0300 Subject: [PATCH 2/2] core,cli: disallow verify methods with non-bool returns --- cli/wallet/wallet.go | 5 +++-- pkg/core/blockchain.go | 3 ++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/cli/wallet/wallet.go b/cli/wallet/wallet.go index 911e9afa9..915321559 100644 --- a/cli/wallet/wallet.go +++ b/cli/wallet/wallet.go @@ -13,6 +13,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/core/native/nativenames" "github.com/nspcc-dev/neo-go/pkg/crypto/keys" "github.com/nspcc-dev/neo-go/pkg/encoding/address" + "github.com/nspcc-dev/neo-go/pkg/smartcontract" "github.com/nspcc-dev/neo-go/pkg/smartcontract/manifest" "github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/wallet" @@ -420,8 +421,8 @@ func importDeployed(ctx *cli.Context) error { return cli.NewExitError(fmt.Errorf("can't fetch contract info: %w", err), 1) } md := cs.Manifest.ABI.GetMethod(manifest.MethodVerify, -1) - if md == nil { - return cli.NewExitError("contract has no `verify` method", 1) + if md == nil || md.ReturnType != smartcontract.BoolType { + return cli.NewExitError("contract has no `verify` method with boolean return", 1) } acc.Address = address.Uint160ToString(cs.Hash) acc.Contract.Script = cs.NEF.Script diff --git a/pkg/core/blockchain.go b/pkg/core/blockchain.go index 94052597c..0e4983d4b 100644 --- a/pkg/core/blockchain.go +++ b/pkg/core/blockchain.go @@ -28,6 +28,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/crypto/keys" "github.com/nspcc-dev/neo-go/pkg/encoding/bigint" "github.com/nspcc-dev/neo-go/pkg/io" + "github.com/nspcc-dev/neo-go/pkg/smartcontract" "github.com/nspcc-dev/neo-go/pkg/smartcontract/callflag" "github.com/nspcc-dev/neo-go/pkg/smartcontract/manifest" "github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger" @@ -1668,7 +1669,7 @@ func (bc *Blockchain) initVerificationVM(ic *interop.Context, hash util.Uint160, return ErrUnknownVerificationContract } md := cs.Manifest.ABI.GetMethod(manifest.MethodVerify, -1) - if md == nil { + if md == nil || md.ReturnType != smartcontract.BoolType { return ErrInvalidVerificationContract } initMD := cs.Manifest.ABI.GetMethod(manifest.MethodInit, 0)