Merge pull request #1689 from nspcc-dev/overload

core: allow to overload contract methods
This commit is contained in:
Roman Khimov 2021-01-27 15:11:54 +03:00 committed by GitHub
commit f1792b32b9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 72 additions and 43 deletions

View file

@ -13,6 +13,7 @@ import (
"github.com/nspcc-dev/neo-go/pkg/core/native/nativenames" "github.com/nspcc-dev/neo-go/pkg/core/native/nativenames"
"github.com/nspcc-dev/neo-go/pkg/crypto/keys" "github.com/nspcc-dev/neo-go/pkg/crypto/keys"
"github.com/nspcc-dev/neo-go/pkg/encoding/address" "github.com/nspcc-dev/neo-go/pkg/encoding/address"
"github.com/nspcc-dev/neo-go/pkg/smartcontract"
"github.com/nspcc-dev/neo-go/pkg/smartcontract/manifest" "github.com/nspcc-dev/neo-go/pkg/smartcontract/manifest"
"github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/util"
"github.com/nspcc-dev/neo-go/pkg/wallet" "github.com/nspcc-dev/neo-go/pkg/wallet"
@ -419,9 +420,9 @@ func importDeployed(ctx *cli.Context) error {
if err != nil { if err != nil {
return cli.NewExitError(fmt.Errorf("can't fetch contract info: %w", err), 1) return cli.NewExitError(fmt.Errorf("can't fetch contract info: %w", err), 1)
} }
md := cs.Manifest.ABI.GetMethod(manifest.MethodVerify) md := cs.Manifest.ABI.GetMethod(manifest.MethodVerify, -1)
if md == nil { if md == nil || md.ReturnType != smartcontract.BoolType {
return cli.NewExitError("contract has no `verify` method", 1) return cli.NewExitError("contract has no `verify` method with boolean return", 1)
} }
acc.Address = address.Uint160ToString(cs.Hash) acc.Address = address.Uint160ToString(cs.Hash)
acc.Contract.Script = cs.NEF.Script acc.Contract.Script = cs.NEF.Script

View file

@ -28,6 +28,7 @@ import (
"github.com/nspcc-dev/neo-go/pkg/crypto/keys" "github.com/nspcc-dev/neo-go/pkg/crypto/keys"
"github.com/nspcc-dev/neo-go/pkg/encoding/bigint" "github.com/nspcc-dev/neo-go/pkg/encoding/bigint"
"github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/io"
"github.com/nspcc-dev/neo-go/pkg/smartcontract"
"github.com/nspcc-dev/neo-go/pkg/smartcontract/callflag" "github.com/nspcc-dev/neo-go/pkg/smartcontract/callflag"
"github.com/nspcc-dev/neo-go/pkg/smartcontract/manifest" "github.com/nspcc-dev/neo-go/pkg/smartcontract/manifest"
"github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger" "github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger"
@ -1667,11 +1668,11 @@ func (bc *Blockchain) initVerificationVM(ic *interop.Context, hash util.Uint160,
if err != nil { if err != nil {
return ErrUnknownVerificationContract return ErrUnknownVerificationContract
} }
md := cs.Manifest.ABI.GetMethod(manifest.MethodVerify) md := cs.Manifest.ABI.GetMethod(manifest.MethodVerify, -1)
if md == nil { if md == nil || md.ReturnType != smartcontract.BoolType {
return ErrInvalidVerificationContract return ErrInvalidVerificationContract
} }
initMD := cs.Manifest.ABI.GetMethod(manifest.MethodInit) initMD := cs.Manifest.ABI.GetMethod(manifest.MethodInit, 0)
v.LoadScriptWithHash(cs.NEF.Script, hash, callflag.ReadStates) v.LoadScriptWithHash(cs.NEF.Script, hash, callflag.ReadStates)
v.Context().NEF = &cs.NEF v.Context().NEF = &cs.NEF
v.Jump(v.Context(), md.Offset) v.Jump(v.Context(), md.Offset)

View file

@ -55,7 +55,7 @@ func Call(ic *interop.Context) error {
if strings.HasPrefix(method, "_") { if strings.HasPrefix(method, "_") {
return errors.New("invalid method name (starts with '_')") return errors.New("invalid method name (starts with '_')")
} }
md := cs.Manifest.ABI.GetMethod(method) md := cs.Manifest.ABI.GetMethod(method, len(args))
if md == nil { if md == nil {
return errors.New("method not found") return errors.New("method not found")
} }
@ -68,7 +68,7 @@ func Call(ic *interop.Context) error {
func callInternal(ic *interop.Context, cs *state.Contract, name string, f callflag.CallFlag, func callInternal(ic *interop.Context, cs *state.Contract, name string, f callflag.CallFlag,
hasReturn bool, args []stackitem.Item) error { hasReturn bool, args []stackitem.Item) error {
md := cs.Manifest.ABI.GetMethod(name) md := cs.Manifest.ABI.GetMethod(name, len(args))
if md.Safe { if md.Safe {
f &^= callflag.WriteStates f &^= callflag.WriteStates
} else if ctx := ic.VM.Context(); ctx != nil && ctx.IsDeployed() { } else if ctx := ic.VM.Context(); ctx != nil && ctx.IsDeployed() {
@ -85,7 +85,7 @@ func callInternal(ic *interop.Context, cs *state.Contract, name string, f callfl
// callExFromNative calls a contract with flags using provided calling hash. // callExFromNative calls a contract with flags using provided calling hash.
func callExFromNative(ic *interop.Context, caller util.Uint160, cs *state.Contract, func callExFromNative(ic *interop.Context, caller util.Uint160, cs *state.Contract,
name string, args []stackitem.Item, f callflag.CallFlag, hasReturn bool) error { name string, args []stackitem.Item, f callflag.CallFlag, hasReturn bool) error {
md := cs.Manifest.ABI.GetMethod(name) md := cs.Manifest.ABI.GetMethod(name, len(args))
if md == nil { if md == nil {
return fmt.Errorf("method '%s' not found", name) return fmt.Errorf("method '%s' not found", name)
} }
@ -119,7 +119,7 @@ func callExFromNative(ic *interop.Context, caller util.Uint160, cs *state.Contra
ic.VM.Context().RetCount = 0 ic.VM.Context().RetCount = 0
} }
md = cs.Manifest.ABI.GetMethod(manifest.MethodInit) md = cs.Manifest.ABI.GetMethod(manifest.MethodInit, 0)
if md != nil { if md != nil {
ic.VM.Call(ic.VM.Context(), md.Offset) ic.VM.Call(ic.VM.Context(), md.Offset)
} }

View file

@ -461,6 +461,8 @@ func getTestContractState(bc *Blockchain) (*state.Contract, *state.Contract) {
emit.Opcodes(w.BinWriter, opcode.ABORT) emit.Opcodes(w.BinWriter, opcode.ABORT)
addOff := w.Len() addOff := w.Len()
emit.Opcodes(w.BinWriter, opcode.ADD, opcode.RET) emit.Opcodes(w.BinWriter, opcode.ADD, opcode.RET)
addMultiOff := w.Len()
emit.Opcodes(w.BinWriter, opcode.ADD, opcode.ADD, opcode.RET)
ret7Off := w.Len() ret7Off := w.Len()
emit.Opcodes(w.BinWriter, opcode.PUSH7, opcode.RET) emit.Opcodes(w.BinWriter, opcode.PUSH7, opcode.RET)
dropOff := w.Len() dropOff := w.Len()
@ -533,6 +535,16 @@ func getTestContractState(bc *Blockchain) (*state.Contract, *state.Contract) {
}, },
ReturnType: smartcontract.IntegerType, ReturnType: smartcontract.IntegerType,
}, },
{
Name: "add",
Offset: addMultiOff,
Parameters: []manifest.Parameter{
manifest.NewParameter("addend1", smartcontract.IntegerType),
manifest.NewParameter("addend2", smartcontract.IntegerType),
manifest.NewParameter("addend3", smartcontract.IntegerType),
},
ReturnType: smartcontract.IntegerType,
},
{ {
Name: "ret7", Name: "ret7",
Offset: ret7Off, Offset: ret7Off,
@ -731,6 +743,7 @@ func TestContractCall(t *testing.T) {
addArgs := stackitem.NewArray([]stackitem.Item{stackitem.Make(1), stackitem.Make(2)}) addArgs := stackitem.NewArray([]stackitem.Item{stackitem.Make(1), stackitem.Make(2)})
t.Run("Good", func(t *testing.T) { t.Run("Good", func(t *testing.T) {
t.Run("2 arguments", func(t *testing.T) {
loadScript(ic, currScript, 42) loadScript(ic, currScript, 42)
ic.VM.Estack().PushVal(addArgs) ic.VM.Estack().PushVal(addArgs)
ic.VM.Estack().PushVal(callflag.All) ic.VM.Estack().PushVal(callflag.All)
@ -742,6 +755,20 @@ func TestContractCall(t *testing.T) {
require.Equal(t, big.NewInt(3), ic.VM.Estack().Pop().Value()) require.Equal(t, big.NewInt(3), ic.VM.Estack().Pop().Value())
require.Equal(t, big.NewInt(42), ic.VM.Estack().Pop().Value()) require.Equal(t, big.NewInt(42), ic.VM.Estack().Pop().Value())
}) })
t.Run("3 arguments", func(t *testing.T) {
loadScript(ic, currScript, 42)
ic.VM.Estack().PushVal(stackitem.NewArray(
append(addArgs.Value().([]stackitem.Item), stackitem.Make(3))))
ic.VM.Estack().PushVal(callflag.All)
ic.VM.Estack().PushVal("add")
ic.VM.Estack().PushVal(h.BytesBE())
require.NoError(t, contract.Call(ic))
require.NoError(t, ic.VM.Run())
require.Equal(t, 2, ic.VM.Estack().Len())
require.Equal(t, big.NewInt(6), ic.VM.Estack().Pop().Value())
require.Equal(t, big.NewInt(42), ic.VM.Estack().Pop().Value())
})
})
t.Run("CallExInvalidFlag", func(t *testing.T) { t.Run("CallExInvalidFlag", func(t *testing.T) {
loadScript(ic, currScript, 42) loadScript(ic, currScript, 42)
@ -778,6 +805,10 @@ func TestContractCall(t *testing.T) {
t.Run("Arguments", runInvalid(1, "add", h.BytesBE())) t.Run("Arguments", runInvalid(1, "add", h.BytesBE()))
t.Run("NotEnoughArguments", runInvalid( t.Run("NotEnoughArguments", runInvalid(
stackitem.NewArray([]stackitem.Item{stackitem.Make(1)}), "add", h.BytesBE())) stackitem.NewArray([]stackitem.Item{stackitem.Make(1)}), "add", h.BytesBE()))
t.Run("TooMuchArguments", runInvalid(
stackitem.NewArray([]stackitem.Item{
stackitem.Make(1), stackitem.Make(2), stackitem.Make(3), stackitem.Make(4)}),
"add", h.BytesBE()))
}) })
t.Run("ReturnValues", func(t *testing.T) { t.Run("ReturnValues", func(t *testing.T) {

View file

@ -379,7 +379,7 @@ func (m *Management) setMinimumDeploymentFee(ic *interop.Context, args []stackit
} }
func (m *Management) callDeploy(ic *interop.Context, cs *state.Contract, isUpdate bool) { func (m *Management) callDeploy(ic *interop.Context, cs *state.Contract, isUpdate bool) {
md := cs.Manifest.ABI.GetMethod(manifest.MethodDeploy) md := cs.Manifest.ABI.GetMethod(manifest.MethodDeploy, 1)
if md != nil { if md != nil {
err := contract.CallFromNative(ic, m.Hash, cs, manifest.MethodDeploy, err := contract.CallFromNative(ic, m.Hash, cs, manifest.MethodDeploy,
[]stackitem.Item{stackitem.NewBool(isUpdate)}, false) []stackitem.Item{stackitem.NewBool(isUpdate)}, false)

View file

@ -227,9 +227,8 @@ func TestContractDeploy(t *testing.T) {
Offset: 0, Offset: 0,
Parameters: []manifest.Parameter{ Parameters: []manifest.Parameter{
manifest.NewParameter("isUpdate", smartcontract.BoolType), manifest.NewParameter("isUpdate", smartcontract.BoolType),
manifest.NewParameter("param", smartcontract.IntegerType),
}, },
ReturnType: smartcontract.VoidType, ReturnType: smartcontract.ArrayType,
}, },
} }
nefD, err := nef.NewFile(deployScript) nefD, err := nef.NewFile(deployScript)

View file

@ -76,9 +76,9 @@ func DefaultManifest(name string) *Manifest {
} }
// GetMethod returns methods with the specified name. // GetMethod returns methods with the specified name.
func (a *ABI) GetMethod(name string) *Method { func (a *ABI) GetMethod(name string, paramCount int) *Method {
for i := range a.Methods { for i := range a.Methods {
if a.Methods[i].Name == name { if a.Methods[i].Name == name && (paramCount == -1 || len(a.Methods[i].Parameters) == paramCount) {
return &a.Methods[i] return &a.Methods[i]
} }
} }

View file

@ -40,15 +40,12 @@ func Check(m *manifest.Manifest, standards ...string) error {
func Comply(m, st *manifest.Manifest) error { func Comply(m, st *manifest.Manifest) error {
for _, stm := range st.ABI.Methods { for _, stm := range st.ABI.Methods {
name := stm.Name name := stm.Name
md := m.ABI.GetMethod(name) md := m.ABI.GetMethod(name, len(stm.Parameters))
if md == nil { if md == nil {
return fmt.Errorf("%w: '%s'", ErrMethodMissing, name) return fmt.Errorf("%w: '%s' with %d parameters", ErrMethodMissing, name, len(stm.Parameters))
} else if stm.ReturnType != md.ReturnType { } else if stm.ReturnType != md.ReturnType {
return fmt.Errorf("%w: '%s' (expected %s, got %s)", ErrInvalidReturnType, return fmt.Errorf("%w: '%s' (expected %s, got %s)", ErrInvalidReturnType,
name, stm.ReturnType, md.ReturnType) name, stm.ReturnType, md.ReturnType)
} else if len(stm.Parameters) != len(md.Parameters) {
return fmt.Errorf("%w: '%s' (expected %d, got %d)", ErrInvalidParameterCount,
name, len(stm.Parameters), len(md.Parameters))
} }
for i := range stm.Parameters { for i := range stm.Parameters {
if stm.Parameters[i].Type != md.Parameters[i].Type { if stm.Parameters[i].Type != md.Parameters[i].Type {

View file

@ -37,14 +37,14 @@ func fooMethodBarEvent() *manifest.Manifest {
func TestComplyMissingMethod(t *testing.T) { func TestComplyMissingMethod(t *testing.T) {
m := fooMethodBarEvent() m := fooMethodBarEvent()
m.ABI.GetMethod("foo").Name = "notafoo" m.ABI.GetMethod("foo", -1).Name = "notafoo"
err := Comply(m, fooMethodBarEvent()) err := Comply(m, fooMethodBarEvent())
require.True(t, errors.Is(err, ErrMethodMissing)) require.True(t, errors.Is(err, ErrMethodMissing))
} }
func TestComplyInvalidReturnType(t *testing.T) { func TestComplyInvalidReturnType(t *testing.T) {
m := fooMethodBarEvent() m := fooMethodBarEvent()
m.ABI.GetMethod("foo").ReturnType = smartcontract.VoidType m.ABI.GetMethod("foo", -1).ReturnType = smartcontract.VoidType
err := Comply(m, fooMethodBarEvent()) err := Comply(m, fooMethodBarEvent())
require.True(t, errors.Is(err, ErrInvalidReturnType)) require.True(t, errors.Is(err, ErrInvalidReturnType))
} }
@ -52,10 +52,10 @@ func TestComplyInvalidReturnType(t *testing.T) {
func TestComplyMethodParameterCount(t *testing.T) { func TestComplyMethodParameterCount(t *testing.T) {
t.Run("Method", func(t *testing.T) { t.Run("Method", func(t *testing.T) {
m := fooMethodBarEvent() m := fooMethodBarEvent()
f := m.ABI.GetMethod("foo") f := m.ABI.GetMethod("foo", -1)
f.Parameters = append(f.Parameters, manifest.Parameter{Type: smartcontract.BoolType}) f.Parameters = append(f.Parameters, manifest.Parameter{Type: smartcontract.BoolType})
err := Comply(m, fooMethodBarEvent()) err := Comply(m, fooMethodBarEvent())
require.True(t, errors.Is(err, ErrInvalidParameterCount)) require.True(t, errors.Is(err, ErrMethodMissing))
}) })
t.Run("Event", func(t *testing.T) { t.Run("Event", func(t *testing.T) {
m := fooMethodBarEvent() m := fooMethodBarEvent()
@ -69,7 +69,7 @@ func TestComplyMethodParameterCount(t *testing.T) {
func TestComplyParameterType(t *testing.T) { func TestComplyParameterType(t *testing.T) {
t.Run("Method", func(t *testing.T) { t.Run("Method", func(t *testing.T) {
m := fooMethodBarEvent() m := fooMethodBarEvent()
m.ABI.GetMethod("foo").Parameters[0].Type = smartcontract.InteropInterfaceType m.ABI.GetMethod("foo", -1).Parameters[0].Type = smartcontract.InteropInterfaceType
err := Comply(m, fooMethodBarEvent()) err := Comply(m, fooMethodBarEvent())
require.True(t, errors.Is(err, ErrInvalidParameterType)) require.True(t, errors.Is(err, ErrInvalidParameterType))
}) })
@ -90,7 +90,7 @@ func TestMissingEvent(t *testing.T) {
func TestSafeFlag(t *testing.T) { func TestSafeFlag(t *testing.T) {
m := fooMethodBarEvent() m := fooMethodBarEvent()
m.ABI.GetMethod("foo").Safe = false m.ABI.GetMethod("foo", -1).Safe = false
err := Comply(m, fooMethodBarEvent()) err := Comply(m, fooMethodBarEvent())
require.True(t, errors.Is(err, ErrSafeMethodMismatch)) require.True(t, errors.Is(err, ErrSafeMethodMismatch))
} }

View file

@ -389,25 +389,25 @@ func handleRun(c *ishell.Context) {
runCurrent = c.Args[0] != "_" runCurrent = c.Args[0] != "_"
) )
params, err = parseArgs(c.Args[1:])
if err != nil {
c.Err(err)
return
}
if runCurrent { if runCurrent {
md := m.ABI.GetMethod(c.Args[0]) md := m.ABI.GetMethod(c.Args[0], len(params))
if md == nil { if md == nil {
c.Err(fmt.Errorf("%w: method not found", ErrInvalidParameter)) c.Err(fmt.Errorf("%w: method not found", ErrInvalidParameter))
return return
} }
offset = md.Offset offset = md.Offset
} }
params, err = parseArgs(c.Args[1:])
if err != nil {
c.Err(err)
return
}
for i := len(params) - 1; i >= 0; i-- { for i := len(params) - 1; i >= 0; i-- {
v.Estack().PushVal(params[i]) v.Estack().PushVal(params[i])
} }
if runCurrent { if runCurrent {
v.Jump(v.Context(), offset) v.Jump(v.Context(), offset)
if initMD := m.ABI.GetMethod(manifest.MethodInit); initMD != nil { if initMD := m.ABI.GetMethod(manifest.MethodInit, 0); initMD != nil {
v.Call(v.Context(), initMD.Offset) v.Call(v.Context(), initMD.Offset)
} }
} }