core: call _deploy method during create/update

This commit is contained in:
Evgenii Stratonikov 2020-10-02 12:56:51 +03:00
parent b71f9e296c
commit 2d9ef9219a
4 changed files with 115 additions and 5 deletions

View file

@ -6,6 +6,7 @@ import (
"strings" "strings"
"github.com/nspcc-dev/neo-go/pkg/core/interop" "github.com/nspcc-dev/neo-go/pkg/core/interop"
"github.com/nspcc-dev/neo-go/pkg/core/state"
"github.com/nspcc-dev/neo-go/pkg/smartcontract" "github.com/nspcc-dev/neo-go/pkg/smartcontract"
"github.com/nspcc-dev/neo-go/pkg/smartcontract/manifest" "github.com/nspcc-dev/neo-go/pkg/smartcontract/manifest"
"github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/util"
@ -44,21 +45,28 @@ func callExInternal(ic *interop.Context, h []byte, name string, args []stackitem
if strings.HasPrefix(name, "_") { if strings.HasPrefix(name, "_") {
return errors.New("invalid method name (starts with '_')") return errors.New("invalid method name (starts with '_')")
} }
md := cs.Manifest.ABI.GetMethod(name)
if md == nil {
return fmt.Errorf("method '%s' not found", name)
}
curr, err := ic.DAO.GetContractState(ic.VM.GetCurrentScriptHash()) curr, err := ic.DAO.GetContractState(ic.VM.GetCurrentScriptHash())
if err == nil { if err == nil {
if !curr.Manifest.CanCall(&cs.Manifest, name) { if !curr.Manifest.CanCall(&cs.Manifest, name) {
return errors.New("disallowed method call") return errors.New("disallowed method call")
} }
} }
return CallExInternal(ic, cs, name, args, f)
}
// CallExInternal calls a contract with flags and can't be invoked directly by user.
func CallExInternal(ic *interop.Context, cs *state.Contract,
name string, args []stackitem.Item, f smartcontract.CallFlag) error {
md := cs.Manifest.ABI.GetMethod(name)
if md == nil {
return fmt.Errorf("method '%s' not found", name)
}
if len(args) != len(md.Parameters) { if len(args) != len(md.Parameters) {
return fmt.Errorf("invalid argument count: %d (expected %d)", len(args), len(md.Parameters)) return fmt.Errorf("invalid argument count: %d (expected %d)", len(args), len(md.Parameters))
} }
u := cs.ScriptHash()
ic.Invocations[u]++ ic.Invocations[u]++
ic.VM.LoadScriptWithHash(cs.Script, u, ic.VM.Context().GetCallFlags()&f) ic.VM.LoadScriptWithHash(cs.Script, u, ic.VM.Context().GetCallFlags()&f)
var isNative bool var isNative bool

View file

@ -9,8 +9,10 @@ import (
"github.com/mr-tron/base58" "github.com/mr-tron/base58"
"github.com/nspcc-dev/neo-go/pkg/core/interop" "github.com/nspcc-dev/neo-go/pkg/core/interop"
"github.com/nspcc-dev/neo-go/pkg/core/interop/contract"
"github.com/nspcc-dev/neo-go/pkg/core/state" "github.com/nspcc-dev/neo-go/pkg/core/state"
"github.com/nspcc-dev/neo-go/pkg/crypto/hash" "github.com/nspcc-dev/neo-go/pkg/crypto/hash"
"github.com/nspcc-dev/neo-go/pkg/smartcontract"
"github.com/nspcc-dev/neo-go/pkg/smartcontract/manifest" "github.com/nspcc-dev/neo-go/pkg/smartcontract/manifest"
"github.com/nspcc-dev/neo-go/pkg/vm" "github.com/nspcc-dev/neo-go/pkg/vm"
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
@ -109,7 +111,7 @@ func contractCreate(ic *interop.Context) error {
return fmt.Errorf("cannot convert contract to stack item: %w", err) return fmt.Errorf("cannot convert contract to stack item: %w", err)
} }
ic.VM.Estack().PushVal(cs) ic.VM.Estack().PushVal(cs)
return nil return callDeploy(ic, newcontract, false)
} }
// contractUpdate migrates a contract. This method assumes that Manifest and Script // contractUpdate migrates a contract. This method assumes that Manifest and Script
@ -183,6 +185,15 @@ func contractUpdate(ic *interop.Context) error {
} }
} }
return callDeploy(ic, contract, true)
}
func callDeploy(ic *interop.Context, cs *state.Contract, isUpdate bool) error {
md := cs.Manifest.ABI.GetMethod(manifest.MethodDeploy)
if md != nil {
return contract.CallExInternal(ic, cs, manifest.MethodDeploy,
[]stackitem.Item{stackitem.NewBool(isUpdate)}, smartcontract.All)
}
return nil return nil
} }

View file

@ -10,6 +10,7 @@ import (
"github.com/nspcc-dev/neo-go/pkg/core/interop" "github.com/nspcc-dev/neo-go/pkg/core/interop"
"github.com/nspcc-dev/neo-go/pkg/core/interop/callback" "github.com/nspcc-dev/neo-go/pkg/core/interop/callback"
"github.com/nspcc-dev/neo-go/pkg/core/interop/contract" "github.com/nspcc-dev/neo-go/pkg/core/interop/contract"
"github.com/nspcc-dev/neo-go/pkg/core/interop/interopnames"
"github.com/nspcc-dev/neo-go/pkg/core/interop/runtime" "github.com/nspcc-dev/neo-go/pkg/core/interop/runtime"
"github.com/nspcc-dev/neo-go/pkg/core/state" "github.com/nspcc-dev/neo-go/pkg/core/state"
"github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/core/transaction"
@ -386,10 +387,26 @@ func getTestContractState() (*state.Contract, *state.Contract) {
verifyOff := w.Len() verifyOff := w.Len()
emit.Opcodes(w.BinWriter, opcode.LDSFLD0, opcode.SUB, emit.Opcodes(w.BinWriter, opcode.LDSFLD0, opcode.SUB,
opcode.CONVERT, opcode.Opcode(stackitem.BooleanT), opcode.RET) opcode.CONVERT, opcode.Opcode(stackitem.BooleanT), opcode.RET)
deployOff := w.Len()
emit.Opcodes(w.BinWriter, opcode.JMPIF, 2+8+3)
emit.String(w.BinWriter, "create")
emit.Opcodes(w.BinWriter, opcode.CALL, 3+8+3, opcode.RET)
emit.String(w.BinWriter, "update")
emit.Opcodes(w.BinWriter, opcode.CALL, 3, opcode.RET)
putValOff := w.Len()
emit.String(w.BinWriter, "initial")
emit.Syscall(w.BinWriter, interopnames.SystemStorageGetContext)
emit.Syscall(w.BinWriter, interopnames.SystemStoragePut)
emit.Opcodes(w.BinWriter, opcode.RET)
getValOff := w.Len()
emit.String(w.BinWriter, "initial")
emit.Syscall(w.BinWriter, interopnames.SystemStorageGetContext)
emit.Syscall(w.BinWriter, interopnames.SystemStorageGet)
script := w.Bytes() script := w.Bytes()
h := hash.Hash160(script) h := hash.Hash160(script)
m := manifest.NewManifest(h) m := manifest.NewManifest(h)
m.Features = smartcontract.HasStorage
m.ABI.Methods = []manifest.Method{ m.ABI.Methods = []manifest.Method{
{ {
Name: "add", Name: "add",
@ -439,6 +456,27 @@ func getTestContractState() (*state.Contract, *state.Contract) {
Offset: verifyOff, Offset: verifyOff,
ReturnType: smartcontract.BoolType, ReturnType: smartcontract.BoolType,
}, },
{
Name: manifest.MethodDeploy,
Offset: deployOff,
Parameters: []manifest.Parameter{
manifest.NewParameter("isUpdate", smartcontract.BoolType),
},
ReturnType: smartcontract.VoidType,
},
{
Name: "getValue",
Offset: getValOff,
ReturnType: smartcontract.StringType,
},
{
Name: "putValue",
Offset: putValOff,
Parameters: []manifest.Parameter{
manifest.NewParameter("value", smartcontract.StringType),
},
ReturnType: smartcontract.VoidType,
},
} }
cs := &state.Contract{ cs := &state.Contract{
Script: script, Script: script,
@ -454,6 +492,7 @@ func getTestContractState() (*state.Contract, *state.Contract) {
perm.Methods.Add("add3") perm.Methods.Add("add3")
perm.Methods.Add("invalidReturn") perm.Methods.Add("invalidReturn")
perm.Methods.Add("justReturn") perm.Methods.Add("justReturn")
perm.Methods.Add("getValue")
m.Permissions = append(m.Permissions, *perm) m.Permissions = append(m.Permissions, *perm)
return cs, &state.Contract{ return cs, &state.Contract{
@ -837,6 +876,55 @@ func TestContractUpdate(t *testing.T) {
}) })
} }
// TestContractCreateDeploy checks that `_deploy` method was called
// during contract creation or update.
func TestContractCreateDeploy(t *testing.T) {
v, ic, bc := createVM(t)
defer bc.Close()
v.GasLimit = -1
putArgs := func(cs *state.Contract) {
rawManifest, err := cs.Manifest.MarshalJSON()
require.NoError(t, err)
v.Estack().PushVal(rawManifest)
v.Estack().PushVal(cs.Script)
}
cs, currCs := getTestContractState()
v.LoadScriptWithFlags([]byte{byte(opcode.RET)}, smartcontract.All)
putArgs(cs)
require.NoError(t, contractCreate(ic))
require.NoError(t, ic.VM.Run())
v.LoadScriptWithFlags(currCs.Script, smartcontract.All)
err := contract.CallExInternal(ic, cs, "getValue", nil, smartcontract.All)
require.NoError(t, err)
require.NoError(t, v.Run())
require.Equal(t, "create", v.Estack().Pop().String())
v.LoadScriptWithFlags(cs.Script, smartcontract.All)
md := cs.Manifest.ABI.GetMethod("justReturn")
v.Jump(v.Context(), md.Offset)
t.Run("Update", func(t *testing.T) {
newCs := &state.Contract{
ID: cs.ID,
Script: append(cs.Script, byte(opcode.RET)),
Manifest: cs.Manifest,
}
newCs.Manifest.ABI.Hash = hash.Hash160(newCs.Script)
putArgs(newCs)
require.NoError(t, contractUpdate(ic))
require.NoError(t, v.Run())
v.LoadScriptWithFlags(currCs.Script, smartcontract.All)
err = contract.CallExInternal(ic, newCs, "getValue", nil, smartcontract.All)
require.NoError(t, err)
require.NoError(t, v.Run())
require.Equal(t, "update", v.Estack().Pop().String())
})
}
func TestContractGetCallFlags(t *testing.T) { func TestContractGetCallFlags(t *testing.T) {
v, ic, bc := createVM(t) v, ic, bc := createVM(t)
defer bc.Close() defer bc.Close()

View file

@ -15,6 +15,9 @@ const (
// MethodInit is a name for default initialization method. // MethodInit is a name for default initialization method.
MethodInit = "_initialize" MethodInit = "_initialize"
// MethodDeploy is a name for default method called during contract deployment.
MethodDeploy = "_deploy"
// MethodVerify is a name for default verification method. // MethodVerify is a name for default verification method.
MethodVerify = "verify" MethodVerify = "verify"