Merge pull request #1689 from nspcc-dev/overload

core: allow to overload contract methods
This commit is contained in:
Roman Khimov 2021-01-27 15:11:54 +03:00 committed by GitHub
commit f1792b32b9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 72 additions and 43 deletions

View file

@ -13,6 +13,7 @@ import (
"github.com/nspcc-dev/neo-go/pkg/core/native/nativenames"
"github.com/nspcc-dev/neo-go/pkg/crypto/keys"
"github.com/nspcc-dev/neo-go/pkg/encoding/address"
"github.com/nspcc-dev/neo-go/pkg/smartcontract"
"github.com/nspcc-dev/neo-go/pkg/smartcontract/manifest"
"github.com/nspcc-dev/neo-go/pkg/util"
"github.com/nspcc-dev/neo-go/pkg/wallet"
@ -419,9 +420,9 @@ func importDeployed(ctx *cli.Context) error {
if err != nil {
return cli.NewExitError(fmt.Errorf("can't fetch contract info: %w", err), 1)
}
md := cs.Manifest.ABI.GetMethod(manifest.MethodVerify)
if md == nil {
return cli.NewExitError("contract has no `verify` method", 1)
md := cs.Manifest.ABI.GetMethod(manifest.MethodVerify, -1)
if md == nil || md.ReturnType != smartcontract.BoolType {
return cli.NewExitError("contract has no `verify` method with boolean return", 1)
}
acc.Address = address.Uint160ToString(cs.Hash)
acc.Contract.Script = cs.NEF.Script

View file

@ -28,6 +28,7 @@ import (
"github.com/nspcc-dev/neo-go/pkg/crypto/keys"
"github.com/nspcc-dev/neo-go/pkg/encoding/bigint"
"github.com/nspcc-dev/neo-go/pkg/io"
"github.com/nspcc-dev/neo-go/pkg/smartcontract"
"github.com/nspcc-dev/neo-go/pkg/smartcontract/callflag"
"github.com/nspcc-dev/neo-go/pkg/smartcontract/manifest"
"github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger"
@ -1667,11 +1668,11 @@ func (bc *Blockchain) initVerificationVM(ic *interop.Context, hash util.Uint160,
if err != nil {
return ErrUnknownVerificationContract
}
md := cs.Manifest.ABI.GetMethod(manifest.MethodVerify)
if md == nil {
md := cs.Manifest.ABI.GetMethod(manifest.MethodVerify, -1)
if md == nil || md.ReturnType != smartcontract.BoolType {
return ErrInvalidVerificationContract
}
initMD := cs.Manifest.ABI.GetMethod(manifest.MethodInit)
initMD := cs.Manifest.ABI.GetMethod(manifest.MethodInit, 0)
v.LoadScriptWithHash(cs.NEF.Script, hash, callflag.ReadStates)
v.Context().NEF = &cs.NEF
v.Jump(v.Context(), md.Offset)

View file

@ -55,7 +55,7 @@ func Call(ic *interop.Context) error {
if strings.HasPrefix(method, "_") {
return errors.New("invalid method name (starts with '_')")
}
md := cs.Manifest.ABI.GetMethod(method)
md := cs.Manifest.ABI.GetMethod(method, len(args))
if md == nil {
return errors.New("method not found")
}
@ -68,7 +68,7 @@ func Call(ic *interop.Context) error {
func callInternal(ic *interop.Context, cs *state.Contract, name string, f callflag.CallFlag,
hasReturn bool, args []stackitem.Item) error {
md := cs.Manifest.ABI.GetMethod(name)
md := cs.Manifest.ABI.GetMethod(name, len(args))
if md.Safe {
f &^= callflag.WriteStates
} else if ctx := ic.VM.Context(); ctx != nil && ctx.IsDeployed() {
@ -85,7 +85,7 @@ func callInternal(ic *interop.Context, cs *state.Contract, name string, f callfl
// callExFromNative calls a contract with flags using provided calling hash.
func callExFromNative(ic *interop.Context, caller util.Uint160, cs *state.Contract,
name string, args []stackitem.Item, f callflag.CallFlag, hasReturn bool) error {
md := cs.Manifest.ABI.GetMethod(name)
md := cs.Manifest.ABI.GetMethod(name, len(args))
if md == nil {
return fmt.Errorf("method '%s' not found", name)
}
@ -119,7 +119,7 @@ func callExFromNative(ic *interop.Context, caller util.Uint160, cs *state.Contra
ic.VM.Context().RetCount = 0
}
md = cs.Manifest.ABI.GetMethod(manifest.MethodInit)
md = cs.Manifest.ABI.GetMethod(manifest.MethodInit, 0)
if md != nil {
ic.VM.Call(ic.VM.Context(), md.Offset)
}

View file

@ -461,6 +461,8 @@ func getTestContractState(bc *Blockchain) (*state.Contract, *state.Contract) {
emit.Opcodes(w.BinWriter, opcode.ABORT)
addOff := w.Len()
emit.Opcodes(w.BinWriter, opcode.ADD, opcode.RET)
addMultiOff := w.Len()
emit.Opcodes(w.BinWriter, opcode.ADD, opcode.ADD, opcode.RET)
ret7Off := w.Len()
emit.Opcodes(w.BinWriter, opcode.PUSH7, opcode.RET)
dropOff := w.Len()
@ -533,6 +535,16 @@ func getTestContractState(bc *Blockchain) (*state.Contract, *state.Contract) {
},
ReturnType: smartcontract.IntegerType,
},
{
Name: "add",
Offset: addMultiOff,
Parameters: []manifest.Parameter{
manifest.NewParameter("addend1", smartcontract.IntegerType),
manifest.NewParameter("addend2", smartcontract.IntegerType),
manifest.NewParameter("addend3", smartcontract.IntegerType),
},
ReturnType: smartcontract.IntegerType,
},
{
Name: "ret7",
Offset: ret7Off,
@ -731,16 +743,31 @@ func TestContractCall(t *testing.T) {
addArgs := stackitem.NewArray([]stackitem.Item{stackitem.Make(1), stackitem.Make(2)})
t.Run("Good", func(t *testing.T) {
loadScript(ic, currScript, 42)
ic.VM.Estack().PushVal(addArgs)
ic.VM.Estack().PushVal(callflag.All)
ic.VM.Estack().PushVal("add")
ic.VM.Estack().PushVal(h.BytesBE())
require.NoError(t, contract.Call(ic))
require.NoError(t, ic.VM.Run())
require.Equal(t, 2, ic.VM.Estack().Len())
require.Equal(t, big.NewInt(3), ic.VM.Estack().Pop().Value())
require.Equal(t, big.NewInt(42), ic.VM.Estack().Pop().Value())
t.Run("2 arguments", func(t *testing.T) {
loadScript(ic, currScript, 42)
ic.VM.Estack().PushVal(addArgs)
ic.VM.Estack().PushVal(callflag.All)
ic.VM.Estack().PushVal("add")
ic.VM.Estack().PushVal(h.BytesBE())
require.NoError(t, contract.Call(ic))
require.NoError(t, ic.VM.Run())
require.Equal(t, 2, ic.VM.Estack().Len())
require.Equal(t, big.NewInt(3), ic.VM.Estack().Pop().Value())
require.Equal(t, big.NewInt(42), ic.VM.Estack().Pop().Value())
})
t.Run("3 arguments", func(t *testing.T) {
loadScript(ic, currScript, 42)
ic.VM.Estack().PushVal(stackitem.NewArray(
append(addArgs.Value().([]stackitem.Item), stackitem.Make(3))))
ic.VM.Estack().PushVal(callflag.All)
ic.VM.Estack().PushVal("add")
ic.VM.Estack().PushVal(h.BytesBE())
require.NoError(t, contract.Call(ic))
require.NoError(t, ic.VM.Run())
require.Equal(t, 2, ic.VM.Estack().Len())
require.Equal(t, big.NewInt(6), ic.VM.Estack().Pop().Value())
require.Equal(t, big.NewInt(42), ic.VM.Estack().Pop().Value())
})
})
t.Run("CallExInvalidFlag", func(t *testing.T) {
@ -778,6 +805,10 @@ func TestContractCall(t *testing.T) {
t.Run("Arguments", runInvalid(1, "add", h.BytesBE()))
t.Run("NotEnoughArguments", runInvalid(
stackitem.NewArray([]stackitem.Item{stackitem.Make(1)}), "add", h.BytesBE()))
t.Run("TooMuchArguments", runInvalid(
stackitem.NewArray([]stackitem.Item{
stackitem.Make(1), stackitem.Make(2), stackitem.Make(3), stackitem.Make(4)}),
"add", h.BytesBE()))
})
t.Run("ReturnValues", func(t *testing.T) {

View file

@ -379,7 +379,7 @@ func (m *Management) setMinimumDeploymentFee(ic *interop.Context, args []stackit
}
func (m *Management) callDeploy(ic *interop.Context, cs *state.Contract, isUpdate bool) {
md := cs.Manifest.ABI.GetMethod(manifest.MethodDeploy)
md := cs.Manifest.ABI.GetMethod(manifest.MethodDeploy, 1)
if md != nil {
err := contract.CallFromNative(ic, m.Hash, cs, manifest.MethodDeploy,
[]stackitem.Item{stackitem.NewBool(isUpdate)}, false)

View file

@ -227,9 +227,8 @@ func TestContractDeploy(t *testing.T) {
Offset: 0,
Parameters: []manifest.Parameter{
manifest.NewParameter("isUpdate", smartcontract.BoolType),
manifest.NewParameter("param", smartcontract.IntegerType),
},
ReturnType: smartcontract.VoidType,
ReturnType: smartcontract.ArrayType,
},
}
nefD, err := nef.NewFile(deployScript)

View file

@ -76,9 +76,9 @@ func DefaultManifest(name string) *Manifest {
}
// GetMethod returns methods with the specified name.
func (a *ABI) GetMethod(name string) *Method {
func (a *ABI) GetMethod(name string, paramCount int) *Method {
for i := range a.Methods {
if a.Methods[i].Name == name {
if a.Methods[i].Name == name && (paramCount == -1 || len(a.Methods[i].Parameters) == paramCount) {
return &a.Methods[i]
}
}

View file

@ -40,15 +40,12 @@ func Check(m *manifest.Manifest, standards ...string) error {
func Comply(m, st *manifest.Manifest) error {
for _, stm := range st.ABI.Methods {
name := stm.Name
md := m.ABI.GetMethod(name)
md := m.ABI.GetMethod(name, len(stm.Parameters))
if md == nil {
return fmt.Errorf("%w: '%s'", ErrMethodMissing, name)
return fmt.Errorf("%w: '%s' with %d parameters", ErrMethodMissing, name, len(stm.Parameters))
} else if stm.ReturnType != md.ReturnType {
return fmt.Errorf("%w: '%s' (expected %s, got %s)", ErrInvalidReturnType,
name, stm.ReturnType, md.ReturnType)
} else if len(stm.Parameters) != len(md.Parameters) {
return fmt.Errorf("%w: '%s' (expected %d, got %d)", ErrInvalidParameterCount,
name, len(stm.Parameters), len(md.Parameters))
}
for i := range stm.Parameters {
if stm.Parameters[i].Type != md.Parameters[i].Type {

View file

@ -37,14 +37,14 @@ func fooMethodBarEvent() *manifest.Manifest {
func TestComplyMissingMethod(t *testing.T) {
m := fooMethodBarEvent()
m.ABI.GetMethod("foo").Name = "notafoo"
m.ABI.GetMethod("foo", -1).Name = "notafoo"
err := Comply(m, fooMethodBarEvent())
require.True(t, errors.Is(err, ErrMethodMissing))
}
func TestComplyInvalidReturnType(t *testing.T) {
m := fooMethodBarEvent()
m.ABI.GetMethod("foo").ReturnType = smartcontract.VoidType
m.ABI.GetMethod("foo", -1).ReturnType = smartcontract.VoidType
err := Comply(m, fooMethodBarEvent())
require.True(t, errors.Is(err, ErrInvalidReturnType))
}
@ -52,10 +52,10 @@ func TestComplyInvalidReturnType(t *testing.T) {
func TestComplyMethodParameterCount(t *testing.T) {
t.Run("Method", func(t *testing.T) {
m := fooMethodBarEvent()
f := m.ABI.GetMethod("foo")
f := m.ABI.GetMethod("foo", -1)
f.Parameters = append(f.Parameters, manifest.Parameter{Type: smartcontract.BoolType})
err := Comply(m, fooMethodBarEvent())
require.True(t, errors.Is(err, ErrInvalidParameterCount))
require.True(t, errors.Is(err, ErrMethodMissing))
})
t.Run("Event", func(t *testing.T) {
m := fooMethodBarEvent()
@ -69,7 +69,7 @@ func TestComplyMethodParameterCount(t *testing.T) {
func TestComplyParameterType(t *testing.T) {
t.Run("Method", func(t *testing.T) {
m := fooMethodBarEvent()
m.ABI.GetMethod("foo").Parameters[0].Type = smartcontract.InteropInterfaceType
m.ABI.GetMethod("foo", -1).Parameters[0].Type = smartcontract.InteropInterfaceType
err := Comply(m, fooMethodBarEvent())
require.True(t, errors.Is(err, ErrInvalidParameterType))
})
@ -90,7 +90,7 @@ func TestMissingEvent(t *testing.T) {
func TestSafeFlag(t *testing.T) {
m := fooMethodBarEvent()
m.ABI.GetMethod("foo").Safe = false
m.ABI.GetMethod("foo", -1).Safe = false
err := Comply(m, fooMethodBarEvent())
require.True(t, errors.Is(err, ErrSafeMethodMismatch))
}

View file

@ -389,25 +389,25 @@ func handleRun(c *ishell.Context) {
runCurrent = c.Args[0] != "_"
)
params, err = parseArgs(c.Args[1:])
if err != nil {
c.Err(err)
return
}
if runCurrent {
md := m.ABI.GetMethod(c.Args[0])
md := m.ABI.GetMethod(c.Args[0], len(params))
if md == nil {
c.Err(fmt.Errorf("%w: method not found", ErrInvalidParameter))
return
}
offset = md.Offset
}
params, err = parseArgs(c.Args[1:])
if err != nil {
c.Err(err)
return
}
for i := len(params) - 1; i >= 0; i-- {
v.Estack().PushVal(params[i])
}
if runCurrent {
v.Jump(v.Context(), offset)
if initMD := m.ABI.GetMethod(manifest.MethodInit); initMD != nil {
if initMD := m.ABI.GetMethod(manifest.MethodInit, 0); initMD != nil {
v.Call(v.Context(), initMD.Offset)
}
}