mirror of
https://github.com/nspcc-dev/neo-go.git
synced 2024-11-22 19:29:39 +00:00
core: call _deploy method during create/update
This commit is contained in:
parent
b71f9e296c
commit
2d9ef9219a
4 changed files with 115 additions and 5 deletions
|
@ -6,6 +6,7 @@ import (
|
|||
"strings"
|
||||
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/interop"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/state"
|
||||
"github.com/nspcc-dev/neo-go/pkg/smartcontract"
|
||||
"github.com/nspcc-dev/neo-go/pkg/smartcontract/manifest"
|
||||
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||
|
@ -44,21 +45,28 @@ func callExInternal(ic *interop.Context, h []byte, name string, args []stackitem
|
|||
if strings.HasPrefix(name, "_") {
|
||||
return errors.New("invalid method name (starts with '_')")
|
||||
}
|
||||
md := cs.Manifest.ABI.GetMethod(name)
|
||||
if md == nil {
|
||||
return fmt.Errorf("method '%s' not found", name)
|
||||
}
|
||||
curr, err := ic.DAO.GetContractState(ic.VM.GetCurrentScriptHash())
|
||||
if err == nil {
|
||||
if !curr.Manifest.CanCall(&cs.Manifest, name) {
|
||||
return errors.New("disallowed method call")
|
||||
}
|
||||
}
|
||||
return CallExInternal(ic, cs, name, args, f)
|
||||
}
|
||||
|
||||
// CallExInternal calls a contract with flags and can't be invoked directly by user.
|
||||
func CallExInternal(ic *interop.Context, cs *state.Contract,
|
||||
name string, args []stackitem.Item, f smartcontract.CallFlag) error {
|
||||
md := cs.Manifest.ABI.GetMethod(name)
|
||||
if md == nil {
|
||||
return fmt.Errorf("method '%s' not found", name)
|
||||
}
|
||||
|
||||
if len(args) != len(md.Parameters) {
|
||||
return fmt.Errorf("invalid argument count: %d (expected %d)", len(args), len(md.Parameters))
|
||||
}
|
||||
|
||||
u := cs.ScriptHash()
|
||||
ic.Invocations[u]++
|
||||
ic.VM.LoadScriptWithHash(cs.Script, u, ic.VM.Context().GetCallFlags()&f)
|
||||
var isNative bool
|
||||
|
|
|
@ -9,8 +9,10 @@ import (
|
|||
|
||||
"github.com/mr-tron/base58"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/interop"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/interop/contract"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/state"
|
||||
"github.com/nspcc-dev/neo-go/pkg/crypto/hash"
|
||||
"github.com/nspcc-dev/neo-go/pkg/smartcontract"
|
||||
"github.com/nspcc-dev/neo-go/pkg/smartcontract/manifest"
|
||||
"github.com/nspcc-dev/neo-go/pkg/vm"
|
||||
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
|
||||
|
@ -109,7 +111,7 @@ func contractCreate(ic *interop.Context) error {
|
|||
return fmt.Errorf("cannot convert contract to stack item: %w", err)
|
||||
}
|
||||
ic.VM.Estack().PushVal(cs)
|
||||
return nil
|
||||
return callDeploy(ic, newcontract, false)
|
||||
}
|
||||
|
||||
// contractUpdate migrates a contract. This method assumes that Manifest and Script
|
||||
|
@ -183,6 +185,15 @@ func contractUpdate(ic *interop.Context) error {
|
|||
}
|
||||
}
|
||||
|
||||
return callDeploy(ic, contract, true)
|
||||
}
|
||||
|
||||
func callDeploy(ic *interop.Context, cs *state.Contract, isUpdate bool) error {
|
||||
md := cs.Manifest.ABI.GetMethod(manifest.MethodDeploy)
|
||||
if md != nil {
|
||||
return contract.CallExInternal(ic, cs, manifest.MethodDeploy,
|
||||
[]stackitem.Item{stackitem.NewBool(isUpdate)}, smartcontract.All)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -10,6 +10,7 @@ import (
|
|||
"github.com/nspcc-dev/neo-go/pkg/core/interop"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/interop/callback"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/interop/contract"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/interop/interopnames"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/interop/runtime"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/state"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
|
||||
|
@ -386,10 +387,26 @@ func getTestContractState() (*state.Contract, *state.Contract) {
|
|||
verifyOff := w.Len()
|
||||
emit.Opcodes(w.BinWriter, opcode.LDSFLD0, opcode.SUB,
|
||||
opcode.CONVERT, opcode.Opcode(stackitem.BooleanT), opcode.RET)
|
||||
deployOff := w.Len()
|
||||
emit.Opcodes(w.BinWriter, opcode.JMPIF, 2+8+3)
|
||||
emit.String(w.BinWriter, "create")
|
||||
emit.Opcodes(w.BinWriter, opcode.CALL, 3+8+3, opcode.RET)
|
||||
emit.String(w.BinWriter, "update")
|
||||
emit.Opcodes(w.BinWriter, opcode.CALL, 3, opcode.RET)
|
||||
putValOff := w.Len()
|
||||
emit.String(w.BinWriter, "initial")
|
||||
emit.Syscall(w.BinWriter, interopnames.SystemStorageGetContext)
|
||||
emit.Syscall(w.BinWriter, interopnames.SystemStoragePut)
|
||||
emit.Opcodes(w.BinWriter, opcode.RET)
|
||||
getValOff := w.Len()
|
||||
emit.String(w.BinWriter, "initial")
|
||||
emit.Syscall(w.BinWriter, interopnames.SystemStorageGetContext)
|
||||
emit.Syscall(w.BinWriter, interopnames.SystemStorageGet)
|
||||
|
||||
script := w.Bytes()
|
||||
h := hash.Hash160(script)
|
||||
m := manifest.NewManifest(h)
|
||||
m.Features = smartcontract.HasStorage
|
||||
m.ABI.Methods = []manifest.Method{
|
||||
{
|
||||
Name: "add",
|
||||
|
@ -439,6 +456,27 @@ func getTestContractState() (*state.Contract, *state.Contract) {
|
|||
Offset: verifyOff,
|
||||
ReturnType: smartcontract.BoolType,
|
||||
},
|
||||
{
|
||||
Name: manifest.MethodDeploy,
|
||||
Offset: deployOff,
|
||||
Parameters: []manifest.Parameter{
|
||||
manifest.NewParameter("isUpdate", smartcontract.BoolType),
|
||||
},
|
||||
ReturnType: smartcontract.VoidType,
|
||||
},
|
||||
{
|
||||
Name: "getValue",
|
||||
Offset: getValOff,
|
||||
ReturnType: smartcontract.StringType,
|
||||
},
|
||||
{
|
||||
Name: "putValue",
|
||||
Offset: putValOff,
|
||||
Parameters: []manifest.Parameter{
|
||||
manifest.NewParameter("value", smartcontract.StringType),
|
||||
},
|
||||
ReturnType: smartcontract.VoidType,
|
||||
},
|
||||
}
|
||||
cs := &state.Contract{
|
||||
Script: script,
|
||||
|
@ -454,6 +492,7 @@ func getTestContractState() (*state.Contract, *state.Contract) {
|
|||
perm.Methods.Add("add3")
|
||||
perm.Methods.Add("invalidReturn")
|
||||
perm.Methods.Add("justReturn")
|
||||
perm.Methods.Add("getValue")
|
||||
m.Permissions = append(m.Permissions, *perm)
|
||||
|
||||
return cs, &state.Contract{
|
||||
|
@ -837,6 +876,55 @@ func TestContractUpdate(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
// TestContractCreateDeploy checks that `_deploy` method was called
|
||||
// during contract creation or update.
|
||||
func TestContractCreateDeploy(t *testing.T) {
|
||||
v, ic, bc := createVM(t)
|
||||
defer bc.Close()
|
||||
v.GasLimit = -1
|
||||
|
||||
putArgs := func(cs *state.Contract) {
|
||||
rawManifest, err := cs.Manifest.MarshalJSON()
|
||||
require.NoError(t, err)
|
||||
v.Estack().PushVal(rawManifest)
|
||||
v.Estack().PushVal(cs.Script)
|
||||
}
|
||||
cs, currCs := getTestContractState()
|
||||
|
||||
v.LoadScriptWithFlags([]byte{byte(opcode.RET)}, smartcontract.All)
|
||||
putArgs(cs)
|
||||
require.NoError(t, contractCreate(ic))
|
||||
require.NoError(t, ic.VM.Run())
|
||||
|
||||
v.LoadScriptWithFlags(currCs.Script, smartcontract.All)
|
||||
err := contract.CallExInternal(ic, cs, "getValue", nil, smartcontract.All)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, v.Run())
|
||||
require.Equal(t, "create", v.Estack().Pop().String())
|
||||
|
||||
v.LoadScriptWithFlags(cs.Script, smartcontract.All)
|
||||
md := cs.Manifest.ABI.GetMethod("justReturn")
|
||||
v.Jump(v.Context(), md.Offset)
|
||||
|
||||
t.Run("Update", func(t *testing.T) {
|
||||
newCs := &state.Contract{
|
||||
ID: cs.ID,
|
||||
Script: append(cs.Script, byte(opcode.RET)),
|
||||
Manifest: cs.Manifest,
|
||||
}
|
||||
newCs.Manifest.ABI.Hash = hash.Hash160(newCs.Script)
|
||||
putArgs(newCs)
|
||||
require.NoError(t, contractUpdate(ic))
|
||||
require.NoError(t, v.Run())
|
||||
|
||||
v.LoadScriptWithFlags(currCs.Script, smartcontract.All)
|
||||
err = contract.CallExInternal(ic, newCs, "getValue", nil, smartcontract.All)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, v.Run())
|
||||
require.Equal(t, "update", v.Estack().Pop().String())
|
||||
})
|
||||
}
|
||||
|
||||
func TestContractGetCallFlags(t *testing.T) {
|
||||
v, ic, bc := createVM(t)
|
||||
defer bc.Close()
|
||||
|
|
|
@ -15,6 +15,9 @@ const (
|
|||
// MethodInit is a name for default initialization method.
|
||||
MethodInit = "_initialize"
|
||||
|
||||
// MethodDeploy is a name for default method called during contract deployment.
|
||||
MethodDeploy = "_deploy"
|
||||
|
||||
// MethodVerify is a name for default verification method.
|
||||
MethodVerify = "verify"
|
||||
|
||||
|
|
Loading…
Reference in a new issue