mirror of
https://github.com/nspcc-dev/neo-go.git
synced 2024-11-26 19:42:23 +00:00
core: call _deploy method during create/update
This commit is contained in:
parent
b71f9e296c
commit
2d9ef9219a
4 changed files with 115 additions and 5 deletions
|
@ -6,6 +6,7 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/nspcc-dev/neo-go/pkg/core/interop"
|
"github.com/nspcc-dev/neo-go/pkg/core/interop"
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/core/state"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/smartcontract"
|
"github.com/nspcc-dev/neo-go/pkg/smartcontract"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/smartcontract/manifest"
|
"github.com/nspcc-dev/neo-go/pkg/smartcontract/manifest"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/util"
|
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||||
|
@ -44,21 +45,28 @@ func callExInternal(ic *interop.Context, h []byte, name string, args []stackitem
|
||||||
if strings.HasPrefix(name, "_") {
|
if strings.HasPrefix(name, "_") {
|
||||||
return errors.New("invalid method name (starts with '_')")
|
return errors.New("invalid method name (starts with '_')")
|
||||||
}
|
}
|
||||||
md := cs.Manifest.ABI.GetMethod(name)
|
|
||||||
if md == nil {
|
|
||||||
return fmt.Errorf("method '%s' not found", name)
|
|
||||||
}
|
|
||||||
curr, err := ic.DAO.GetContractState(ic.VM.GetCurrentScriptHash())
|
curr, err := ic.DAO.GetContractState(ic.VM.GetCurrentScriptHash())
|
||||||
if err == nil {
|
if err == nil {
|
||||||
if !curr.Manifest.CanCall(&cs.Manifest, name) {
|
if !curr.Manifest.CanCall(&cs.Manifest, name) {
|
||||||
return errors.New("disallowed method call")
|
return errors.New("disallowed method call")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return CallExInternal(ic, cs, name, args, f)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CallExInternal calls a contract with flags and can't be invoked directly by user.
|
||||||
|
func CallExInternal(ic *interop.Context, cs *state.Contract,
|
||||||
|
name string, args []stackitem.Item, f smartcontract.CallFlag) error {
|
||||||
|
md := cs.Manifest.ABI.GetMethod(name)
|
||||||
|
if md == nil {
|
||||||
|
return fmt.Errorf("method '%s' not found", name)
|
||||||
|
}
|
||||||
|
|
||||||
if len(args) != len(md.Parameters) {
|
if len(args) != len(md.Parameters) {
|
||||||
return fmt.Errorf("invalid argument count: %d (expected %d)", len(args), len(md.Parameters))
|
return fmt.Errorf("invalid argument count: %d (expected %d)", len(args), len(md.Parameters))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
u := cs.ScriptHash()
|
||||||
ic.Invocations[u]++
|
ic.Invocations[u]++
|
||||||
ic.VM.LoadScriptWithHash(cs.Script, u, ic.VM.Context().GetCallFlags()&f)
|
ic.VM.LoadScriptWithHash(cs.Script, u, ic.VM.Context().GetCallFlags()&f)
|
||||||
var isNative bool
|
var isNative bool
|
||||||
|
|
|
@ -9,8 +9,10 @@ import (
|
||||||
|
|
||||||
"github.com/mr-tron/base58"
|
"github.com/mr-tron/base58"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/core/interop"
|
"github.com/nspcc-dev/neo-go/pkg/core/interop"
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/core/interop/contract"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/core/state"
|
"github.com/nspcc-dev/neo-go/pkg/core/state"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/crypto/hash"
|
"github.com/nspcc-dev/neo-go/pkg/crypto/hash"
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/smartcontract"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/smartcontract/manifest"
|
"github.com/nspcc-dev/neo-go/pkg/smartcontract/manifest"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/vm"
|
"github.com/nspcc-dev/neo-go/pkg/vm"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
|
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
|
||||||
|
@ -109,7 +111,7 @@ func contractCreate(ic *interop.Context) error {
|
||||||
return fmt.Errorf("cannot convert contract to stack item: %w", err)
|
return fmt.Errorf("cannot convert contract to stack item: %w", err)
|
||||||
}
|
}
|
||||||
ic.VM.Estack().PushVal(cs)
|
ic.VM.Estack().PushVal(cs)
|
||||||
return nil
|
return callDeploy(ic, newcontract, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
// contractUpdate migrates a contract. This method assumes that Manifest and Script
|
// contractUpdate migrates a contract. This method assumes that Manifest and Script
|
||||||
|
@ -183,6 +185,15 @@ func contractUpdate(ic *interop.Context) error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return callDeploy(ic, contract, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
func callDeploy(ic *interop.Context, cs *state.Contract, isUpdate bool) error {
|
||||||
|
md := cs.Manifest.ABI.GetMethod(manifest.MethodDeploy)
|
||||||
|
if md != nil {
|
||||||
|
return contract.CallExInternal(ic, cs, manifest.MethodDeploy,
|
||||||
|
[]stackitem.Item{stackitem.NewBool(isUpdate)}, smartcontract.All)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -10,6 +10,7 @@ import (
|
||||||
"github.com/nspcc-dev/neo-go/pkg/core/interop"
|
"github.com/nspcc-dev/neo-go/pkg/core/interop"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/core/interop/callback"
|
"github.com/nspcc-dev/neo-go/pkg/core/interop/callback"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/core/interop/contract"
|
"github.com/nspcc-dev/neo-go/pkg/core/interop/contract"
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/core/interop/interopnames"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/core/interop/runtime"
|
"github.com/nspcc-dev/neo-go/pkg/core/interop/runtime"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/core/state"
|
"github.com/nspcc-dev/neo-go/pkg/core/state"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
|
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
|
||||||
|
@ -386,10 +387,26 @@ func getTestContractState() (*state.Contract, *state.Contract) {
|
||||||
verifyOff := w.Len()
|
verifyOff := w.Len()
|
||||||
emit.Opcodes(w.BinWriter, opcode.LDSFLD0, opcode.SUB,
|
emit.Opcodes(w.BinWriter, opcode.LDSFLD0, opcode.SUB,
|
||||||
opcode.CONVERT, opcode.Opcode(stackitem.BooleanT), opcode.RET)
|
opcode.CONVERT, opcode.Opcode(stackitem.BooleanT), opcode.RET)
|
||||||
|
deployOff := w.Len()
|
||||||
|
emit.Opcodes(w.BinWriter, opcode.JMPIF, 2+8+3)
|
||||||
|
emit.String(w.BinWriter, "create")
|
||||||
|
emit.Opcodes(w.BinWriter, opcode.CALL, 3+8+3, opcode.RET)
|
||||||
|
emit.String(w.BinWriter, "update")
|
||||||
|
emit.Opcodes(w.BinWriter, opcode.CALL, 3, opcode.RET)
|
||||||
|
putValOff := w.Len()
|
||||||
|
emit.String(w.BinWriter, "initial")
|
||||||
|
emit.Syscall(w.BinWriter, interopnames.SystemStorageGetContext)
|
||||||
|
emit.Syscall(w.BinWriter, interopnames.SystemStoragePut)
|
||||||
|
emit.Opcodes(w.BinWriter, opcode.RET)
|
||||||
|
getValOff := w.Len()
|
||||||
|
emit.String(w.BinWriter, "initial")
|
||||||
|
emit.Syscall(w.BinWriter, interopnames.SystemStorageGetContext)
|
||||||
|
emit.Syscall(w.BinWriter, interopnames.SystemStorageGet)
|
||||||
|
|
||||||
script := w.Bytes()
|
script := w.Bytes()
|
||||||
h := hash.Hash160(script)
|
h := hash.Hash160(script)
|
||||||
m := manifest.NewManifest(h)
|
m := manifest.NewManifest(h)
|
||||||
|
m.Features = smartcontract.HasStorage
|
||||||
m.ABI.Methods = []manifest.Method{
|
m.ABI.Methods = []manifest.Method{
|
||||||
{
|
{
|
||||||
Name: "add",
|
Name: "add",
|
||||||
|
@ -439,6 +456,27 @@ func getTestContractState() (*state.Contract, *state.Contract) {
|
||||||
Offset: verifyOff,
|
Offset: verifyOff,
|
||||||
ReturnType: smartcontract.BoolType,
|
ReturnType: smartcontract.BoolType,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
Name: manifest.MethodDeploy,
|
||||||
|
Offset: deployOff,
|
||||||
|
Parameters: []manifest.Parameter{
|
||||||
|
manifest.NewParameter("isUpdate", smartcontract.BoolType),
|
||||||
|
},
|
||||||
|
ReturnType: smartcontract.VoidType,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "getValue",
|
||||||
|
Offset: getValOff,
|
||||||
|
ReturnType: smartcontract.StringType,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "putValue",
|
||||||
|
Offset: putValOff,
|
||||||
|
Parameters: []manifest.Parameter{
|
||||||
|
manifest.NewParameter("value", smartcontract.StringType),
|
||||||
|
},
|
||||||
|
ReturnType: smartcontract.VoidType,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
cs := &state.Contract{
|
cs := &state.Contract{
|
||||||
Script: script,
|
Script: script,
|
||||||
|
@ -454,6 +492,7 @@ func getTestContractState() (*state.Contract, *state.Contract) {
|
||||||
perm.Methods.Add("add3")
|
perm.Methods.Add("add3")
|
||||||
perm.Methods.Add("invalidReturn")
|
perm.Methods.Add("invalidReturn")
|
||||||
perm.Methods.Add("justReturn")
|
perm.Methods.Add("justReturn")
|
||||||
|
perm.Methods.Add("getValue")
|
||||||
m.Permissions = append(m.Permissions, *perm)
|
m.Permissions = append(m.Permissions, *perm)
|
||||||
|
|
||||||
return cs, &state.Contract{
|
return cs, &state.Contract{
|
||||||
|
@ -837,6 +876,55 @@ func TestContractUpdate(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestContractCreateDeploy checks that `_deploy` method was called
|
||||||
|
// during contract creation or update.
|
||||||
|
func TestContractCreateDeploy(t *testing.T) {
|
||||||
|
v, ic, bc := createVM(t)
|
||||||
|
defer bc.Close()
|
||||||
|
v.GasLimit = -1
|
||||||
|
|
||||||
|
putArgs := func(cs *state.Contract) {
|
||||||
|
rawManifest, err := cs.Manifest.MarshalJSON()
|
||||||
|
require.NoError(t, err)
|
||||||
|
v.Estack().PushVal(rawManifest)
|
||||||
|
v.Estack().PushVal(cs.Script)
|
||||||
|
}
|
||||||
|
cs, currCs := getTestContractState()
|
||||||
|
|
||||||
|
v.LoadScriptWithFlags([]byte{byte(opcode.RET)}, smartcontract.All)
|
||||||
|
putArgs(cs)
|
||||||
|
require.NoError(t, contractCreate(ic))
|
||||||
|
require.NoError(t, ic.VM.Run())
|
||||||
|
|
||||||
|
v.LoadScriptWithFlags(currCs.Script, smartcontract.All)
|
||||||
|
err := contract.CallExInternal(ic, cs, "getValue", nil, smartcontract.All)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, v.Run())
|
||||||
|
require.Equal(t, "create", v.Estack().Pop().String())
|
||||||
|
|
||||||
|
v.LoadScriptWithFlags(cs.Script, smartcontract.All)
|
||||||
|
md := cs.Manifest.ABI.GetMethod("justReturn")
|
||||||
|
v.Jump(v.Context(), md.Offset)
|
||||||
|
|
||||||
|
t.Run("Update", func(t *testing.T) {
|
||||||
|
newCs := &state.Contract{
|
||||||
|
ID: cs.ID,
|
||||||
|
Script: append(cs.Script, byte(opcode.RET)),
|
||||||
|
Manifest: cs.Manifest,
|
||||||
|
}
|
||||||
|
newCs.Manifest.ABI.Hash = hash.Hash160(newCs.Script)
|
||||||
|
putArgs(newCs)
|
||||||
|
require.NoError(t, contractUpdate(ic))
|
||||||
|
require.NoError(t, v.Run())
|
||||||
|
|
||||||
|
v.LoadScriptWithFlags(currCs.Script, smartcontract.All)
|
||||||
|
err = contract.CallExInternal(ic, newCs, "getValue", nil, smartcontract.All)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, v.Run())
|
||||||
|
require.Equal(t, "update", v.Estack().Pop().String())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestContractGetCallFlags(t *testing.T) {
|
func TestContractGetCallFlags(t *testing.T) {
|
||||||
v, ic, bc := createVM(t)
|
v, ic, bc := createVM(t)
|
||||||
defer bc.Close()
|
defer bc.Close()
|
||||||
|
|
|
@ -15,6 +15,9 @@ const (
|
||||||
// MethodInit is a name for default initialization method.
|
// MethodInit is a name for default initialization method.
|
||||||
MethodInit = "_initialize"
|
MethodInit = "_initialize"
|
||||||
|
|
||||||
|
// MethodDeploy is a name for default method called during contract deployment.
|
||||||
|
MethodDeploy = "_deploy"
|
||||||
|
|
||||||
// MethodVerify is a name for default verification method.
|
// MethodVerify is a name for default verification method.
|
||||||
MethodVerify = "verify"
|
MethodVerify = "verify"
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue