Merge pull request #1689 from nspcc-dev/overload
core: allow to overload contract methods
This commit is contained in:
commit
f1792b32b9
10 changed files with 72 additions and 43 deletions
|
@ -13,6 +13,7 @@ import (
|
|||
"github.com/nspcc-dev/neo-go/pkg/core/native/nativenames"
|
||||
"github.com/nspcc-dev/neo-go/pkg/crypto/keys"
|
||||
"github.com/nspcc-dev/neo-go/pkg/encoding/address"
|
||||
"github.com/nspcc-dev/neo-go/pkg/smartcontract"
|
||||
"github.com/nspcc-dev/neo-go/pkg/smartcontract/manifest"
|
||||
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||
"github.com/nspcc-dev/neo-go/pkg/wallet"
|
||||
|
@ -419,9 +420,9 @@ func importDeployed(ctx *cli.Context) error {
|
|||
if err != nil {
|
||||
return cli.NewExitError(fmt.Errorf("can't fetch contract info: %w", err), 1)
|
||||
}
|
||||
md := cs.Manifest.ABI.GetMethod(manifest.MethodVerify)
|
||||
if md == nil {
|
||||
return cli.NewExitError("contract has no `verify` method", 1)
|
||||
md := cs.Manifest.ABI.GetMethod(manifest.MethodVerify, -1)
|
||||
if md == nil || md.ReturnType != smartcontract.BoolType {
|
||||
return cli.NewExitError("contract has no `verify` method with boolean return", 1)
|
||||
}
|
||||
acc.Address = address.Uint160ToString(cs.Hash)
|
||||
acc.Contract.Script = cs.NEF.Script
|
||||
|
|
|
@ -28,6 +28,7 @@ import (
|
|||
"github.com/nspcc-dev/neo-go/pkg/crypto/keys"
|
||||
"github.com/nspcc-dev/neo-go/pkg/encoding/bigint"
|
||||
"github.com/nspcc-dev/neo-go/pkg/io"
|
||||
"github.com/nspcc-dev/neo-go/pkg/smartcontract"
|
||||
"github.com/nspcc-dev/neo-go/pkg/smartcontract/callflag"
|
||||
"github.com/nspcc-dev/neo-go/pkg/smartcontract/manifest"
|
||||
"github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger"
|
||||
|
@ -1667,11 +1668,11 @@ func (bc *Blockchain) initVerificationVM(ic *interop.Context, hash util.Uint160,
|
|||
if err != nil {
|
||||
return ErrUnknownVerificationContract
|
||||
}
|
||||
md := cs.Manifest.ABI.GetMethod(manifest.MethodVerify)
|
||||
if md == nil {
|
||||
md := cs.Manifest.ABI.GetMethod(manifest.MethodVerify, -1)
|
||||
if md == nil || md.ReturnType != smartcontract.BoolType {
|
||||
return ErrInvalidVerificationContract
|
||||
}
|
||||
initMD := cs.Manifest.ABI.GetMethod(manifest.MethodInit)
|
||||
initMD := cs.Manifest.ABI.GetMethod(manifest.MethodInit, 0)
|
||||
v.LoadScriptWithHash(cs.NEF.Script, hash, callflag.ReadStates)
|
||||
v.Context().NEF = &cs.NEF
|
||||
v.Jump(v.Context(), md.Offset)
|
||||
|
|
|
@ -55,7 +55,7 @@ func Call(ic *interop.Context) error {
|
|||
if strings.HasPrefix(method, "_") {
|
||||
return errors.New("invalid method name (starts with '_')")
|
||||
}
|
||||
md := cs.Manifest.ABI.GetMethod(method)
|
||||
md := cs.Manifest.ABI.GetMethod(method, len(args))
|
||||
if md == nil {
|
||||
return errors.New("method not found")
|
||||
}
|
||||
|
@ -68,7 +68,7 @@ func Call(ic *interop.Context) error {
|
|||
|
||||
func callInternal(ic *interop.Context, cs *state.Contract, name string, f callflag.CallFlag,
|
||||
hasReturn bool, args []stackitem.Item) error {
|
||||
md := cs.Manifest.ABI.GetMethod(name)
|
||||
md := cs.Manifest.ABI.GetMethod(name, len(args))
|
||||
if md.Safe {
|
||||
f &^= callflag.WriteStates
|
||||
} else if ctx := ic.VM.Context(); ctx != nil && ctx.IsDeployed() {
|
||||
|
@ -85,7 +85,7 @@ func callInternal(ic *interop.Context, cs *state.Contract, name string, f callfl
|
|||
// callExFromNative calls a contract with flags using provided calling hash.
|
||||
func callExFromNative(ic *interop.Context, caller util.Uint160, cs *state.Contract,
|
||||
name string, args []stackitem.Item, f callflag.CallFlag, hasReturn bool) error {
|
||||
md := cs.Manifest.ABI.GetMethod(name)
|
||||
md := cs.Manifest.ABI.GetMethod(name, len(args))
|
||||
if md == nil {
|
||||
return fmt.Errorf("method '%s' not found", name)
|
||||
}
|
||||
|
@ -119,7 +119,7 @@ func callExFromNative(ic *interop.Context, caller util.Uint160, cs *state.Contra
|
|||
ic.VM.Context().RetCount = 0
|
||||
}
|
||||
|
||||
md = cs.Manifest.ABI.GetMethod(manifest.MethodInit)
|
||||
md = cs.Manifest.ABI.GetMethod(manifest.MethodInit, 0)
|
||||
if md != nil {
|
||||
ic.VM.Call(ic.VM.Context(), md.Offset)
|
||||
}
|
||||
|
|
|
@ -461,6 +461,8 @@ func getTestContractState(bc *Blockchain) (*state.Contract, *state.Contract) {
|
|||
emit.Opcodes(w.BinWriter, opcode.ABORT)
|
||||
addOff := w.Len()
|
||||
emit.Opcodes(w.BinWriter, opcode.ADD, opcode.RET)
|
||||
addMultiOff := w.Len()
|
||||
emit.Opcodes(w.BinWriter, opcode.ADD, opcode.ADD, opcode.RET)
|
||||
ret7Off := w.Len()
|
||||
emit.Opcodes(w.BinWriter, opcode.PUSH7, opcode.RET)
|
||||
dropOff := w.Len()
|
||||
|
@ -533,6 +535,16 @@ func getTestContractState(bc *Blockchain) (*state.Contract, *state.Contract) {
|
|||
},
|
||||
ReturnType: smartcontract.IntegerType,
|
||||
},
|
||||
{
|
||||
Name: "add",
|
||||
Offset: addMultiOff,
|
||||
Parameters: []manifest.Parameter{
|
||||
manifest.NewParameter("addend1", smartcontract.IntegerType),
|
||||
manifest.NewParameter("addend2", smartcontract.IntegerType),
|
||||
manifest.NewParameter("addend3", smartcontract.IntegerType),
|
||||
},
|
||||
ReturnType: smartcontract.IntegerType,
|
||||
},
|
||||
{
|
||||
Name: "ret7",
|
||||
Offset: ret7Off,
|
||||
|
@ -731,6 +743,7 @@ func TestContractCall(t *testing.T) {
|
|||
|
||||
addArgs := stackitem.NewArray([]stackitem.Item{stackitem.Make(1), stackitem.Make(2)})
|
||||
t.Run("Good", func(t *testing.T) {
|
||||
t.Run("2 arguments", func(t *testing.T) {
|
||||
loadScript(ic, currScript, 42)
|
||||
ic.VM.Estack().PushVal(addArgs)
|
||||
ic.VM.Estack().PushVal(callflag.All)
|
||||
|
@ -742,6 +755,20 @@ func TestContractCall(t *testing.T) {
|
|||
require.Equal(t, big.NewInt(3), ic.VM.Estack().Pop().Value())
|
||||
require.Equal(t, big.NewInt(42), ic.VM.Estack().Pop().Value())
|
||||
})
|
||||
t.Run("3 arguments", func(t *testing.T) {
|
||||
loadScript(ic, currScript, 42)
|
||||
ic.VM.Estack().PushVal(stackitem.NewArray(
|
||||
append(addArgs.Value().([]stackitem.Item), stackitem.Make(3))))
|
||||
ic.VM.Estack().PushVal(callflag.All)
|
||||
ic.VM.Estack().PushVal("add")
|
||||
ic.VM.Estack().PushVal(h.BytesBE())
|
||||
require.NoError(t, contract.Call(ic))
|
||||
require.NoError(t, ic.VM.Run())
|
||||
require.Equal(t, 2, ic.VM.Estack().Len())
|
||||
require.Equal(t, big.NewInt(6), ic.VM.Estack().Pop().Value())
|
||||
require.Equal(t, big.NewInt(42), ic.VM.Estack().Pop().Value())
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("CallExInvalidFlag", func(t *testing.T) {
|
||||
loadScript(ic, currScript, 42)
|
||||
|
@ -778,6 +805,10 @@ func TestContractCall(t *testing.T) {
|
|||
t.Run("Arguments", runInvalid(1, "add", h.BytesBE()))
|
||||
t.Run("NotEnoughArguments", runInvalid(
|
||||
stackitem.NewArray([]stackitem.Item{stackitem.Make(1)}), "add", h.BytesBE()))
|
||||
t.Run("TooMuchArguments", runInvalid(
|
||||
stackitem.NewArray([]stackitem.Item{
|
||||
stackitem.Make(1), stackitem.Make(2), stackitem.Make(3), stackitem.Make(4)}),
|
||||
"add", h.BytesBE()))
|
||||
})
|
||||
|
||||
t.Run("ReturnValues", func(t *testing.T) {
|
||||
|
|
|
@ -379,7 +379,7 @@ func (m *Management) setMinimumDeploymentFee(ic *interop.Context, args []stackit
|
|||
}
|
||||
|
||||
func (m *Management) callDeploy(ic *interop.Context, cs *state.Contract, isUpdate bool) {
|
||||
md := cs.Manifest.ABI.GetMethod(manifest.MethodDeploy)
|
||||
md := cs.Manifest.ABI.GetMethod(manifest.MethodDeploy, 1)
|
||||
if md != nil {
|
||||
err := contract.CallFromNative(ic, m.Hash, cs, manifest.MethodDeploy,
|
||||
[]stackitem.Item{stackitem.NewBool(isUpdate)}, false)
|
||||
|
|
|
@ -227,9 +227,8 @@ func TestContractDeploy(t *testing.T) {
|
|||
Offset: 0,
|
||||
Parameters: []manifest.Parameter{
|
||||
manifest.NewParameter("isUpdate", smartcontract.BoolType),
|
||||
manifest.NewParameter("param", smartcontract.IntegerType),
|
||||
},
|
||||
ReturnType: smartcontract.VoidType,
|
||||
ReturnType: smartcontract.ArrayType,
|
||||
},
|
||||
}
|
||||
nefD, err := nef.NewFile(deployScript)
|
||||
|
|
|
@ -76,9 +76,9 @@ func DefaultManifest(name string) *Manifest {
|
|||
}
|
||||
|
||||
// GetMethod returns methods with the specified name.
|
||||
func (a *ABI) GetMethod(name string) *Method {
|
||||
func (a *ABI) GetMethod(name string, paramCount int) *Method {
|
||||
for i := range a.Methods {
|
||||
if a.Methods[i].Name == name {
|
||||
if a.Methods[i].Name == name && (paramCount == -1 || len(a.Methods[i].Parameters) == paramCount) {
|
||||
return &a.Methods[i]
|
||||
}
|
||||
}
|
||||
|
|
|
@ -40,15 +40,12 @@ func Check(m *manifest.Manifest, standards ...string) error {
|
|||
func Comply(m, st *manifest.Manifest) error {
|
||||
for _, stm := range st.ABI.Methods {
|
||||
name := stm.Name
|
||||
md := m.ABI.GetMethod(name)
|
||||
md := m.ABI.GetMethod(name, len(stm.Parameters))
|
||||
if md == nil {
|
||||
return fmt.Errorf("%w: '%s'", ErrMethodMissing, name)
|
||||
return fmt.Errorf("%w: '%s' with %d parameters", ErrMethodMissing, name, len(stm.Parameters))
|
||||
} else if stm.ReturnType != md.ReturnType {
|
||||
return fmt.Errorf("%w: '%s' (expected %s, got %s)", ErrInvalidReturnType,
|
||||
name, stm.ReturnType, md.ReturnType)
|
||||
} else if len(stm.Parameters) != len(md.Parameters) {
|
||||
return fmt.Errorf("%w: '%s' (expected %d, got %d)", ErrInvalidParameterCount,
|
||||
name, len(stm.Parameters), len(md.Parameters))
|
||||
}
|
||||
for i := range stm.Parameters {
|
||||
if stm.Parameters[i].Type != md.Parameters[i].Type {
|
||||
|
|
|
@ -37,14 +37,14 @@ func fooMethodBarEvent() *manifest.Manifest {
|
|||
|
||||
func TestComplyMissingMethod(t *testing.T) {
|
||||
m := fooMethodBarEvent()
|
||||
m.ABI.GetMethod("foo").Name = "notafoo"
|
||||
m.ABI.GetMethod("foo", -1).Name = "notafoo"
|
||||
err := Comply(m, fooMethodBarEvent())
|
||||
require.True(t, errors.Is(err, ErrMethodMissing))
|
||||
}
|
||||
|
||||
func TestComplyInvalidReturnType(t *testing.T) {
|
||||
m := fooMethodBarEvent()
|
||||
m.ABI.GetMethod("foo").ReturnType = smartcontract.VoidType
|
||||
m.ABI.GetMethod("foo", -1).ReturnType = smartcontract.VoidType
|
||||
err := Comply(m, fooMethodBarEvent())
|
||||
require.True(t, errors.Is(err, ErrInvalidReturnType))
|
||||
}
|
||||
|
@ -52,10 +52,10 @@ func TestComplyInvalidReturnType(t *testing.T) {
|
|||
func TestComplyMethodParameterCount(t *testing.T) {
|
||||
t.Run("Method", func(t *testing.T) {
|
||||
m := fooMethodBarEvent()
|
||||
f := m.ABI.GetMethod("foo")
|
||||
f := m.ABI.GetMethod("foo", -1)
|
||||
f.Parameters = append(f.Parameters, manifest.Parameter{Type: smartcontract.BoolType})
|
||||
err := Comply(m, fooMethodBarEvent())
|
||||
require.True(t, errors.Is(err, ErrInvalidParameterCount))
|
||||
require.True(t, errors.Is(err, ErrMethodMissing))
|
||||
})
|
||||
t.Run("Event", func(t *testing.T) {
|
||||
m := fooMethodBarEvent()
|
||||
|
@ -69,7 +69,7 @@ func TestComplyMethodParameterCount(t *testing.T) {
|
|||
func TestComplyParameterType(t *testing.T) {
|
||||
t.Run("Method", func(t *testing.T) {
|
||||
m := fooMethodBarEvent()
|
||||
m.ABI.GetMethod("foo").Parameters[0].Type = smartcontract.InteropInterfaceType
|
||||
m.ABI.GetMethod("foo", -1).Parameters[0].Type = smartcontract.InteropInterfaceType
|
||||
err := Comply(m, fooMethodBarEvent())
|
||||
require.True(t, errors.Is(err, ErrInvalidParameterType))
|
||||
})
|
||||
|
@ -90,7 +90,7 @@ func TestMissingEvent(t *testing.T) {
|
|||
|
||||
func TestSafeFlag(t *testing.T) {
|
||||
m := fooMethodBarEvent()
|
||||
m.ABI.GetMethod("foo").Safe = false
|
||||
m.ABI.GetMethod("foo", -1).Safe = false
|
||||
err := Comply(m, fooMethodBarEvent())
|
||||
require.True(t, errors.Is(err, ErrSafeMethodMismatch))
|
||||
}
|
||||
|
|
|
@ -389,25 +389,25 @@ func handleRun(c *ishell.Context) {
|
|||
runCurrent = c.Args[0] != "_"
|
||||
)
|
||||
|
||||
params, err = parseArgs(c.Args[1:])
|
||||
if err != nil {
|
||||
c.Err(err)
|
||||
return
|
||||
}
|
||||
if runCurrent {
|
||||
md := m.ABI.GetMethod(c.Args[0])
|
||||
md := m.ABI.GetMethod(c.Args[0], len(params))
|
||||
if md == nil {
|
||||
c.Err(fmt.Errorf("%w: method not found", ErrInvalidParameter))
|
||||
return
|
||||
}
|
||||
offset = md.Offset
|
||||
}
|
||||
params, err = parseArgs(c.Args[1:])
|
||||
if err != nil {
|
||||
c.Err(err)
|
||||
return
|
||||
}
|
||||
for i := len(params) - 1; i >= 0; i-- {
|
||||
v.Estack().PushVal(params[i])
|
||||
}
|
||||
if runCurrent {
|
||||
v.Jump(v.Context(), offset)
|
||||
if initMD := m.ABI.GetMethod(manifest.MethodInit); initMD != nil {
|
||||
if initMD := m.ABI.GetMethod(manifest.MethodInit, 0); initMD != nil {
|
||||
v.Call(v.Context(), initMD.Offset)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue