mirror of
https://github.com/nspcc-dev/neo-go.git
synced 2025-01-25 15:14:48 +00:00
core: allow to overload contract methods
Multiple methods with different parameter count can co-exist.
This commit is contained in:
parent
32e86785fa
commit
73f888f02e
10 changed files with 67 additions and 40 deletions
|
@ -419,7 +419,7 @@ func importDeployed(ctx *cli.Context) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return cli.NewExitError(fmt.Errorf("can't fetch contract info: %w", err), 1)
|
return cli.NewExitError(fmt.Errorf("can't fetch contract info: %w", err), 1)
|
||||||
}
|
}
|
||||||
md := cs.Manifest.ABI.GetMethod(manifest.MethodVerify)
|
md := cs.Manifest.ABI.GetMethod(manifest.MethodVerify, -1)
|
||||||
if md == nil {
|
if md == nil {
|
||||||
return cli.NewExitError("contract has no `verify` method", 1)
|
return cli.NewExitError("contract has no `verify` method", 1)
|
||||||
}
|
}
|
||||||
|
|
|
@ -1667,11 +1667,11 @@ func (bc *Blockchain) initVerificationVM(ic *interop.Context, hash util.Uint160,
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ErrUnknownVerificationContract
|
return ErrUnknownVerificationContract
|
||||||
}
|
}
|
||||||
md := cs.Manifest.ABI.GetMethod(manifest.MethodVerify)
|
md := cs.Manifest.ABI.GetMethod(manifest.MethodVerify, -1)
|
||||||
if md == nil {
|
if md == nil {
|
||||||
return ErrInvalidVerificationContract
|
return ErrInvalidVerificationContract
|
||||||
}
|
}
|
||||||
initMD := cs.Manifest.ABI.GetMethod(manifest.MethodInit)
|
initMD := cs.Manifest.ABI.GetMethod(manifest.MethodInit, 0)
|
||||||
v.LoadScriptWithHash(cs.NEF.Script, hash, callflag.ReadStates)
|
v.LoadScriptWithHash(cs.NEF.Script, hash, callflag.ReadStates)
|
||||||
v.Context().NEF = &cs.NEF
|
v.Context().NEF = &cs.NEF
|
||||||
v.Jump(v.Context(), md.Offset)
|
v.Jump(v.Context(), md.Offset)
|
||||||
|
|
|
@ -55,7 +55,7 @@ func Call(ic *interop.Context) error {
|
||||||
if strings.HasPrefix(method, "_") {
|
if strings.HasPrefix(method, "_") {
|
||||||
return errors.New("invalid method name (starts with '_')")
|
return errors.New("invalid method name (starts with '_')")
|
||||||
}
|
}
|
||||||
md := cs.Manifest.ABI.GetMethod(method)
|
md := cs.Manifest.ABI.GetMethod(method, len(args))
|
||||||
if md == nil {
|
if md == nil {
|
||||||
return errors.New("method not found")
|
return errors.New("method not found")
|
||||||
}
|
}
|
||||||
|
@ -68,7 +68,7 @@ func Call(ic *interop.Context) error {
|
||||||
|
|
||||||
func callInternal(ic *interop.Context, cs *state.Contract, name string, f callflag.CallFlag,
|
func callInternal(ic *interop.Context, cs *state.Contract, name string, f callflag.CallFlag,
|
||||||
hasReturn bool, args []stackitem.Item) error {
|
hasReturn bool, args []stackitem.Item) error {
|
||||||
md := cs.Manifest.ABI.GetMethod(name)
|
md := cs.Manifest.ABI.GetMethod(name, len(args))
|
||||||
if md.Safe {
|
if md.Safe {
|
||||||
f &^= callflag.WriteStates
|
f &^= callflag.WriteStates
|
||||||
} else if ctx := ic.VM.Context(); ctx != nil && ctx.IsDeployed() {
|
} else if ctx := ic.VM.Context(); ctx != nil && ctx.IsDeployed() {
|
||||||
|
@ -85,7 +85,7 @@ func callInternal(ic *interop.Context, cs *state.Contract, name string, f callfl
|
||||||
// callExFromNative calls a contract with flags using provided calling hash.
|
// callExFromNative calls a contract with flags using provided calling hash.
|
||||||
func callExFromNative(ic *interop.Context, caller util.Uint160, cs *state.Contract,
|
func callExFromNative(ic *interop.Context, caller util.Uint160, cs *state.Contract,
|
||||||
name string, args []stackitem.Item, f callflag.CallFlag, hasReturn bool) error {
|
name string, args []stackitem.Item, f callflag.CallFlag, hasReturn bool) error {
|
||||||
md := cs.Manifest.ABI.GetMethod(name)
|
md := cs.Manifest.ABI.GetMethod(name, len(args))
|
||||||
if md == nil {
|
if md == nil {
|
||||||
return fmt.Errorf("method '%s' not found", name)
|
return fmt.Errorf("method '%s' not found", name)
|
||||||
}
|
}
|
||||||
|
@ -119,7 +119,7 @@ func callExFromNative(ic *interop.Context, caller util.Uint160, cs *state.Contra
|
||||||
ic.VM.Context().RetCount = 0
|
ic.VM.Context().RetCount = 0
|
||||||
}
|
}
|
||||||
|
|
||||||
md = cs.Manifest.ABI.GetMethod(manifest.MethodInit)
|
md = cs.Manifest.ABI.GetMethod(manifest.MethodInit, 0)
|
||||||
if md != nil {
|
if md != nil {
|
||||||
ic.VM.Call(ic.VM.Context(), md.Offset)
|
ic.VM.Call(ic.VM.Context(), md.Offset)
|
||||||
}
|
}
|
||||||
|
|
|
@ -461,6 +461,8 @@ func getTestContractState(bc *Blockchain) (*state.Contract, *state.Contract) {
|
||||||
emit.Opcodes(w.BinWriter, opcode.ABORT)
|
emit.Opcodes(w.BinWriter, opcode.ABORT)
|
||||||
addOff := w.Len()
|
addOff := w.Len()
|
||||||
emit.Opcodes(w.BinWriter, opcode.ADD, opcode.RET)
|
emit.Opcodes(w.BinWriter, opcode.ADD, opcode.RET)
|
||||||
|
addMultiOff := w.Len()
|
||||||
|
emit.Opcodes(w.BinWriter, opcode.ADD, opcode.ADD, opcode.RET)
|
||||||
ret7Off := w.Len()
|
ret7Off := w.Len()
|
||||||
emit.Opcodes(w.BinWriter, opcode.PUSH7, opcode.RET)
|
emit.Opcodes(w.BinWriter, opcode.PUSH7, opcode.RET)
|
||||||
dropOff := w.Len()
|
dropOff := w.Len()
|
||||||
|
@ -533,6 +535,16 @@ func getTestContractState(bc *Blockchain) (*state.Contract, *state.Contract) {
|
||||||
},
|
},
|
||||||
ReturnType: smartcontract.IntegerType,
|
ReturnType: smartcontract.IntegerType,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
Name: "add",
|
||||||
|
Offset: addMultiOff,
|
||||||
|
Parameters: []manifest.Parameter{
|
||||||
|
manifest.NewParameter("addend1", smartcontract.IntegerType),
|
||||||
|
manifest.NewParameter("addend2", smartcontract.IntegerType),
|
||||||
|
manifest.NewParameter("addend3", smartcontract.IntegerType),
|
||||||
|
},
|
||||||
|
ReturnType: smartcontract.IntegerType,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
Name: "ret7",
|
Name: "ret7",
|
||||||
Offset: ret7Off,
|
Offset: ret7Off,
|
||||||
|
@ -731,6 +743,7 @@ func TestContractCall(t *testing.T) {
|
||||||
|
|
||||||
addArgs := stackitem.NewArray([]stackitem.Item{stackitem.Make(1), stackitem.Make(2)})
|
addArgs := stackitem.NewArray([]stackitem.Item{stackitem.Make(1), stackitem.Make(2)})
|
||||||
t.Run("Good", func(t *testing.T) {
|
t.Run("Good", func(t *testing.T) {
|
||||||
|
t.Run("2 arguments", func(t *testing.T) {
|
||||||
loadScript(ic, currScript, 42)
|
loadScript(ic, currScript, 42)
|
||||||
ic.VM.Estack().PushVal(addArgs)
|
ic.VM.Estack().PushVal(addArgs)
|
||||||
ic.VM.Estack().PushVal(callflag.All)
|
ic.VM.Estack().PushVal(callflag.All)
|
||||||
|
@ -742,6 +755,20 @@ func TestContractCall(t *testing.T) {
|
||||||
require.Equal(t, big.NewInt(3), ic.VM.Estack().Pop().Value())
|
require.Equal(t, big.NewInt(3), ic.VM.Estack().Pop().Value())
|
||||||
require.Equal(t, big.NewInt(42), ic.VM.Estack().Pop().Value())
|
require.Equal(t, big.NewInt(42), ic.VM.Estack().Pop().Value())
|
||||||
})
|
})
|
||||||
|
t.Run("3 arguments", func(t *testing.T) {
|
||||||
|
loadScript(ic, currScript, 42)
|
||||||
|
ic.VM.Estack().PushVal(stackitem.NewArray(
|
||||||
|
append(addArgs.Value().([]stackitem.Item), stackitem.Make(3))))
|
||||||
|
ic.VM.Estack().PushVal(callflag.All)
|
||||||
|
ic.VM.Estack().PushVal("add")
|
||||||
|
ic.VM.Estack().PushVal(h.BytesBE())
|
||||||
|
require.NoError(t, contract.Call(ic))
|
||||||
|
require.NoError(t, ic.VM.Run())
|
||||||
|
require.Equal(t, 2, ic.VM.Estack().Len())
|
||||||
|
require.Equal(t, big.NewInt(6), ic.VM.Estack().Pop().Value())
|
||||||
|
require.Equal(t, big.NewInt(42), ic.VM.Estack().Pop().Value())
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
t.Run("CallExInvalidFlag", func(t *testing.T) {
|
t.Run("CallExInvalidFlag", func(t *testing.T) {
|
||||||
loadScript(ic, currScript, 42)
|
loadScript(ic, currScript, 42)
|
||||||
|
@ -778,6 +805,10 @@ func TestContractCall(t *testing.T) {
|
||||||
t.Run("Arguments", runInvalid(1, "add", h.BytesBE()))
|
t.Run("Arguments", runInvalid(1, "add", h.BytesBE()))
|
||||||
t.Run("NotEnoughArguments", runInvalid(
|
t.Run("NotEnoughArguments", runInvalid(
|
||||||
stackitem.NewArray([]stackitem.Item{stackitem.Make(1)}), "add", h.BytesBE()))
|
stackitem.NewArray([]stackitem.Item{stackitem.Make(1)}), "add", h.BytesBE()))
|
||||||
|
t.Run("TooMuchArguments", runInvalid(
|
||||||
|
stackitem.NewArray([]stackitem.Item{
|
||||||
|
stackitem.Make(1), stackitem.Make(2), stackitem.Make(3), stackitem.Make(4)}),
|
||||||
|
"add", h.BytesBE()))
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("ReturnValues", func(t *testing.T) {
|
t.Run("ReturnValues", func(t *testing.T) {
|
||||||
|
|
|
@ -379,7 +379,7 @@ func (m *Management) setMinimumDeploymentFee(ic *interop.Context, args []stackit
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Management) callDeploy(ic *interop.Context, cs *state.Contract, isUpdate bool) {
|
func (m *Management) callDeploy(ic *interop.Context, cs *state.Contract, isUpdate bool) {
|
||||||
md := cs.Manifest.ABI.GetMethod(manifest.MethodDeploy)
|
md := cs.Manifest.ABI.GetMethod(manifest.MethodDeploy, 1)
|
||||||
if md != nil {
|
if md != nil {
|
||||||
err := contract.CallFromNative(ic, m.Hash, cs, manifest.MethodDeploy,
|
err := contract.CallFromNative(ic, m.Hash, cs, manifest.MethodDeploy,
|
||||||
[]stackitem.Item{stackitem.NewBool(isUpdate)}, false)
|
[]stackitem.Item{stackitem.NewBool(isUpdate)}, false)
|
||||||
|
|
|
@ -227,9 +227,8 @@ func TestContractDeploy(t *testing.T) {
|
||||||
Offset: 0,
|
Offset: 0,
|
||||||
Parameters: []manifest.Parameter{
|
Parameters: []manifest.Parameter{
|
||||||
manifest.NewParameter("isUpdate", smartcontract.BoolType),
|
manifest.NewParameter("isUpdate", smartcontract.BoolType),
|
||||||
manifest.NewParameter("param", smartcontract.IntegerType),
|
|
||||||
},
|
},
|
||||||
ReturnType: smartcontract.VoidType,
|
ReturnType: smartcontract.ArrayType,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
nefD, err := nef.NewFile(deployScript)
|
nefD, err := nef.NewFile(deployScript)
|
||||||
|
|
|
@ -76,9 +76,9 @@ func DefaultManifest(name string) *Manifest {
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetMethod returns methods with the specified name.
|
// GetMethod returns methods with the specified name.
|
||||||
func (a *ABI) GetMethod(name string) *Method {
|
func (a *ABI) GetMethod(name string, paramCount int) *Method {
|
||||||
for i := range a.Methods {
|
for i := range a.Methods {
|
||||||
if a.Methods[i].Name == name {
|
if a.Methods[i].Name == name && (paramCount == -1 || len(a.Methods[i].Parameters) == paramCount) {
|
||||||
return &a.Methods[i]
|
return &a.Methods[i]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -40,15 +40,12 @@ func Check(m *manifest.Manifest, standards ...string) error {
|
||||||
func Comply(m, st *manifest.Manifest) error {
|
func Comply(m, st *manifest.Manifest) error {
|
||||||
for _, stm := range st.ABI.Methods {
|
for _, stm := range st.ABI.Methods {
|
||||||
name := stm.Name
|
name := stm.Name
|
||||||
md := m.ABI.GetMethod(name)
|
md := m.ABI.GetMethod(name, len(stm.Parameters))
|
||||||
if md == nil {
|
if md == nil {
|
||||||
return fmt.Errorf("%w: '%s'", ErrMethodMissing, name)
|
return fmt.Errorf("%w: '%s' with %d parameters", ErrMethodMissing, name, len(stm.Parameters))
|
||||||
} else if stm.ReturnType != md.ReturnType {
|
} else if stm.ReturnType != md.ReturnType {
|
||||||
return fmt.Errorf("%w: '%s' (expected %s, got %s)", ErrInvalidReturnType,
|
return fmt.Errorf("%w: '%s' (expected %s, got %s)", ErrInvalidReturnType,
|
||||||
name, stm.ReturnType, md.ReturnType)
|
name, stm.ReturnType, md.ReturnType)
|
||||||
} else if len(stm.Parameters) != len(md.Parameters) {
|
|
||||||
return fmt.Errorf("%w: '%s' (expected %d, got %d)", ErrInvalidParameterCount,
|
|
||||||
name, len(stm.Parameters), len(md.Parameters))
|
|
||||||
}
|
}
|
||||||
for i := range stm.Parameters {
|
for i := range stm.Parameters {
|
||||||
if stm.Parameters[i].Type != md.Parameters[i].Type {
|
if stm.Parameters[i].Type != md.Parameters[i].Type {
|
||||||
|
|
|
@ -37,14 +37,14 @@ func fooMethodBarEvent() *manifest.Manifest {
|
||||||
|
|
||||||
func TestComplyMissingMethod(t *testing.T) {
|
func TestComplyMissingMethod(t *testing.T) {
|
||||||
m := fooMethodBarEvent()
|
m := fooMethodBarEvent()
|
||||||
m.ABI.GetMethod("foo").Name = "notafoo"
|
m.ABI.GetMethod("foo", -1).Name = "notafoo"
|
||||||
err := Comply(m, fooMethodBarEvent())
|
err := Comply(m, fooMethodBarEvent())
|
||||||
require.True(t, errors.Is(err, ErrMethodMissing))
|
require.True(t, errors.Is(err, ErrMethodMissing))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestComplyInvalidReturnType(t *testing.T) {
|
func TestComplyInvalidReturnType(t *testing.T) {
|
||||||
m := fooMethodBarEvent()
|
m := fooMethodBarEvent()
|
||||||
m.ABI.GetMethod("foo").ReturnType = smartcontract.VoidType
|
m.ABI.GetMethod("foo", -1).ReturnType = smartcontract.VoidType
|
||||||
err := Comply(m, fooMethodBarEvent())
|
err := Comply(m, fooMethodBarEvent())
|
||||||
require.True(t, errors.Is(err, ErrInvalidReturnType))
|
require.True(t, errors.Is(err, ErrInvalidReturnType))
|
||||||
}
|
}
|
||||||
|
@ -52,10 +52,10 @@ func TestComplyInvalidReturnType(t *testing.T) {
|
||||||
func TestComplyMethodParameterCount(t *testing.T) {
|
func TestComplyMethodParameterCount(t *testing.T) {
|
||||||
t.Run("Method", func(t *testing.T) {
|
t.Run("Method", func(t *testing.T) {
|
||||||
m := fooMethodBarEvent()
|
m := fooMethodBarEvent()
|
||||||
f := m.ABI.GetMethod("foo")
|
f := m.ABI.GetMethod("foo", -1)
|
||||||
f.Parameters = append(f.Parameters, manifest.Parameter{Type: smartcontract.BoolType})
|
f.Parameters = append(f.Parameters, manifest.Parameter{Type: smartcontract.BoolType})
|
||||||
err := Comply(m, fooMethodBarEvent())
|
err := Comply(m, fooMethodBarEvent())
|
||||||
require.True(t, errors.Is(err, ErrInvalidParameterCount))
|
require.True(t, errors.Is(err, ErrMethodMissing))
|
||||||
})
|
})
|
||||||
t.Run("Event", func(t *testing.T) {
|
t.Run("Event", func(t *testing.T) {
|
||||||
m := fooMethodBarEvent()
|
m := fooMethodBarEvent()
|
||||||
|
@ -69,7 +69,7 @@ func TestComplyMethodParameterCount(t *testing.T) {
|
||||||
func TestComplyParameterType(t *testing.T) {
|
func TestComplyParameterType(t *testing.T) {
|
||||||
t.Run("Method", func(t *testing.T) {
|
t.Run("Method", func(t *testing.T) {
|
||||||
m := fooMethodBarEvent()
|
m := fooMethodBarEvent()
|
||||||
m.ABI.GetMethod("foo").Parameters[0].Type = smartcontract.InteropInterfaceType
|
m.ABI.GetMethod("foo", -1).Parameters[0].Type = smartcontract.InteropInterfaceType
|
||||||
err := Comply(m, fooMethodBarEvent())
|
err := Comply(m, fooMethodBarEvent())
|
||||||
require.True(t, errors.Is(err, ErrInvalidParameterType))
|
require.True(t, errors.Is(err, ErrInvalidParameterType))
|
||||||
})
|
})
|
||||||
|
@ -90,7 +90,7 @@ func TestMissingEvent(t *testing.T) {
|
||||||
|
|
||||||
func TestSafeFlag(t *testing.T) {
|
func TestSafeFlag(t *testing.T) {
|
||||||
m := fooMethodBarEvent()
|
m := fooMethodBarEvent()
|
||||||
m.ABI.GetMethod("foo").Safe = false
|
m.ABI.GetMethod("foo", -1).Safe = false
|
||||||
err := Comply(m, fooMethodBarEvent())
|
err := Comply(m, fooMethodBarEvent())
|
||||||
require.True(t, errors.Is(err, ErrSafeMethodMismatch))
|
require.True(t, errors.Is(err, ErrSafeMethodMismatch))
|
||||||
}
|
}
|
||||||
|
|
|
@ -389,25 +389,25 @@ func handleRun(c *ishell.Context) {
|
||||||
runCurrent = c.Args[0] != "_"
|
runCurrent = c.Args[0] != "_"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
params, err = parseArgs(c.Args[1:])
|
||||||
|
if err != nil {
|
||||||
|
c.Err(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
if runCurrent {
|
if runCurrent {
|
||||||
md := m.ABI.GetMethod(c.Args[0])
|
md := m.ABI.GetMethod(c.Args[0], len(params))
|
||||||
if md == nil {
|
if md == nil {
|
||||||
c.Err(fmt.Errorf("%w: method not found", ErrInvalidParameter))
|
c.Err(fmt.Errorf("%w: method not found", ErrInvalidParameter))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
offset = md.Offset
|
offset = md.Offset
|
||||||
}
|
}
|
||||||
params, err = parseArgs(c.Args[1:])
|
|
||||||
if err != nil {
|
|
||||||
c.Err(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
for i := len(params) - 1; i >= 0; i-- {
|
for i := len(params) - 1; i >= 0; i-- {
|
||||||
v.Estack().PushVal(params[i])
|
v.Estack().PushVal(params[i])
|
||||||
}
|
}
|
||||||
if runCurrent {
|
if runCurrent {
|
||||||
v.Jump(v.Context(), offset)
|
v.Jump(v.Context(), offset)
|
||||||
if initMD := m.ABI.GetMethod(manifest.MethodInit); initMD != nil {
|
if initMD := m.ABI.GetMethod(manifest.MethodInit, 0); initMD != nil {
|
||||||
v.Call(v.Context(), initMD.Offset)
|
v.Call(v.Context(), initMD.Offset)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue