Merge pull request #1673 from nspcc-dev/vm/calltoken

Implement CALLT opcode
This commit is contained in:
Roman Khimov 2021-01-22 09:40:16 +03:00 committed by GitHub
commit 07583332cd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 125 additions and 16 deletions

View file

@ -16,6 +16,7 @@ import (
"github.com/nspcc-dev/neo-go/pkg/core/blockchainer" "github.com/nspcc-dev/neo-go/pkg/core/blockchainer"
"github.com/nspcc-dev/neo-go/pkg/core/dao" "github.com/nspcc-dev/neo-go/pkg/core/dao"
"github.com/nspcc-dev/neo-go/pkg/core/interop" "github.com/nspcc-dev/neo-go/pkg/core/interop"
"github.com/nspcc-dev/neo-go/pkg/core/interop/contract"
"github.com/nspcc-dev/neo-go/pkg/core/mempool" "github.com/nspcc-dev/neo-go/pkg/core/mempool"
"github.com/nspcc-dev/neo-go/pkg/core/mpt" "github.com/nspcc-dev/neo-go/pkg/core/mpt"
"github.com/nspcc-dev/neo-go/pkg/core/native" "github.com/nspcc-dev/neo-go/pkg/core/native"
@ -625,6 +626,7 @@ func (bc *Blockchain) storeBlock(block *block.Block, txpool *mempool.Pool) error
v := systemInterop.SpawnVM() v := systemInterop.SpawnVM()
v.LoadScriptWithFlags(tx.Script, callflag.All) v.LoadScriptWithFlags(tx.Script, callflag.All)
v.SetPriceGetter(bc.getPrice) v.SetPriceGetter(bc.getPrice)
v.LoadToken = contract.LoadToken(systemInterop)
v.GasLimit = tx.SystemFee v.GasLimit = tx.SystemFee
err := v.Run() err := v.Run()
@ -1635,6 +1637,7 @@ func (bc *Blockchain) GetTestVM(t trigger.Type, tx *transaction.Transaction, b *
systemInterop := bc.newInteropContext(t, d, b, tx) systemInterop := bc.newInteropContext(t, d, b, tx)
vm := systemInterop.SpawnVM() vm := systemInterop.SpawnVM()
vm.SetPriceGetter(bc.getPrice) vm.SetPriceGetter(bc.getPrice)
vm.LoadToken = contract.LoadToken(systemInterop)
return vm return vm
} }
@ -1670,6 +1673,7 @@ func (bc *Blockchain) initVerificationVM(ic *interop.Context, hash util.Uint160,
} }
initMD := cs.Manifest.ABI.GetMethod(manifest.MethodInit) initMD := cs.Manifest.ABI.GetMethod(manifest.MethodInit)
v.LoadScriptWithHash(cs.NEF.Script, hash, callflag.ReadStates) v.LoadScriptWithHash(cs.NEF.Script, hash, callflag.ReadStates)
v.Context().NEF = &cs.NEF
v.Jump(v.Context(), md.Offset) v.Jump(v.Context(), md.Offset)
if cs.ID <= 0 { if cs.ID <= 0 {
@ -1704,6 +1708,7 @@ func (bc *Blockchain) verifyHashAgainstScript(hash util.Uint160, witness *transa
vm := interopCtx.SpawnVM() vm := interopCtx.SpawnVM()
vm.SetPriceGetter(bc.getPrice) vm.SetPriceGetter(bc.getPrice)
vm.LoadToken = contract.LoadToken(interopCtx)
vm.GasLimit = gas vm.GasLimit = gas
if err := bc.initVerificationVM(interopCtx, hash, witness); err != nil { if err := bc.initVerificationVM(interopCtx, hash, witness); err != nil {
return 0, err return 0, err

View file

@ -15,6 +15,26 @@ import (
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
) )
// LoadToken calls method specified by token id.
func LoadToken(ic *interop.Context) func(id int32) error {
return func(id int32) error {
ctx := ic.VM.Context()
tok := ctx.NEF.Tokens[id]
if int(tok.ParamCount) > ctx.Estack().Len() {
return errors.New("stack is too small")
}
args := make([]stackitem.Item, tok.ParamCount)
for i := range args {
args[i] = ic.VM.Estack().Pop().Item()
}
cs, err := ic.GetContract(tok.Hash)
if err != nil {
return fmt.Errorf("contract not found: %w", err)
}
return callInternal(ic, cs, tok.Method, tok.CallFlag, tok.HasReturn, args)
}
}
// Call calls a contract with flags. // Call calls a contract with flags.
func Call(ic *interop.Context) error { func Call(ic *interop.Context) error {
h := ic.VM.Estack().Pop().Bytes() h := ic.VM.Estack().Pop().Bytes()
@ -24,10 +44,6 @@ func Call(ic *interop.Context) error {
return errors.New("call flags out of range") return errors.New("call flags out of range")
} }
args := ic.VM.Estack().Pop().Array() args := ic.VM.Estack().Pop().Array()
return callInternal(ic, h, method, fs, args)
}
func callInternal(ic *interop.Context, h []byte, name string, f callflag.CallFlag, args []stackitem.Item) error {
u, err := util.Uint160DecodeBytesBE(h) u, err := util.Uint160DecodeBytesBE(h)
if err != nil { if err != nil {
return errors.New("invalid contract hash") return errors.New("invalid contract hash")
@ -36,10 +52,10 @@ func callInternal(ic *interop.Context, h []byte, name string, f callflag.CallFla
if err != nil { if err != nil {
return fmt.Errorf("contract not found: %w", err) return fmt.Errorf("contract not found: %w", err)
} }
if strings.HasPrefix(name, "_") { if strings.HasPrefix(method, "_") {
return errors.New("invalid method name (starts with '_')") return errors.New("invalid method name (starts with '_')")
} }
md := cs.Manifest.ABI.GetMethod(name) md := cs.Manifest.ABI.GetMethod(method)
if md == nil { if md == nil {
return errors.New("method not found") return errors.New("method not found")
} }
@ -47,12 +63,18 @@ func callInternal(ic *interop.Context, h []byte, name string, f callflag.CallFla
if !hasReturn { if !hasReturn {
ic.VM.Estack().PushVal(stackitem.Null{}) ic.VM.Estack().PushVal(stackitem.Null{})
} }
return callInternal(ic, cs, method, fs, hasReturn, args)
}
func callInternal(ic *interop.Context, cs *state.Contract, name string, f callflag.CallFlag,
hasReturn bool, args []stackitem.Item) error {
md := cs.Manifest.ABI.GetMethod(name)
if md.Safe { if md.Safe {
f &^= callflag.WriteStates f &^= callflag.WriteStates
} else if ctx := ic.VM.Context(); ctx != nil && ctx.IsDeployed() { } else if ctx := ic.VM.Context(); ctx != nil && ctx.IsDeployed() {
curr, err := ic.GetContract(ic.VM.GetCurrentScriptHash()) curr, err := ic.GetContract(ic.VM.GetCurrentScriptHash())
if err == nil { if err == nil {
if !curr.Manifest.CanCall(u, &cs.Manifest, name) { if !curr.Manifest.CanCall(cs.Hash, &cs.Manifest, name) {
return errors.New("disallowed method call") return errors.New("disallowed method call")
} }
} }
@ -74,6 +96,7 @@ func callExFromNative(ic *interop.Context, caller util.Uint160, cs *state.Contra
ic.VM.Invocations[cs.Hash]++ ic.VM.Invocations[cs.Hash]++
ic.VM.LoadScriptWithCallingHash(caller, cs.NEF.Script, cs.Hash, ic.VM.Context().GetCallFlags()&f, true, uint16(len(args))) ic.VM.LoadScriptWithCallingHash(caller, cs.NEF.Script, cs.Hash, ic.VM.Context().GetCallFlags()&f, true, uint16(len(args)))
ic.VM.Context().NEF = &cs.NEF
var isNative bool var isNative bool
for i := range ic.Natives { for i := range ic.Natives {
if ic.Natives[i].Metadata().Hash.Equals(cs.Hash) { if ic.Natives[i].Metadata().Hash.Equals(cs.Hash) {

View file

@ -513,6 +513,12 @@ func getTestContractState(bc *Blockchain) (*state.Contract, *state.Contract) {
emit.Opcodes(w.BinWriter, opcode.NEWARRAY0, opcode.DUP, opcode.DUP, opcode.APPEND, opcode.NEWMAP) emit.Opcodes(w.BinWriter, opcode.NEWARRAY0, opcode.DUP, opcode.DUP, opcode.APPEND, opcode.NEWMAP)
emit.Syscall(w.BinWriter, interopnames.SystemIteratorCreate) emit.Syscall(w.BinWriter, interopnames.SystemIteratorCreate)
emit.Opcodes(w.BinWriter, opcode.RET) emit.Opcodes(w.BinWriter, opcode.RET)
callT0Off := w.Len()
emit.Opcodes(w.BinWriter, opcode.CALLT, 0, 0, opcode.PUSH1, opcode.ADD, opcode.RET)
callT1Off := w.Len()
emit.Opcodes(w.BinWriter, opcode.CALLT, 1, 0, opcode.RET)
callT2Off := w.Len()
emit.Opcodes(w.BinWriter, opcode.CALLT, 0, 0, opcode.RET)
script := w.Bytes() script := w.Bytes()
h := hash.Hash160(script) h := hash.Hash160(script)
@ -616,7 +622,34 @@ func getTestContractState(bc *Blockchain) (*state.Contract, *state.Contract) {
Offset: invalidStackOff, Offset: invalidStackOff,
ReturnType: smartcontract.VoidType, ReturnType: smartcontract.VoidType,
}, },
{
Name: "callT0",
Offset: callT0Off,
Parameters: []manifest.Parameter{
manifest.NewParameter("address", smartcontract.Hash160Type),
},
ReturnType: smartcontract.IntegerType,
},
{
Name: "callT1",
Offset: callT1Off,
ReturnType: smartcontract.IntegerType,
},
{
Name: "callT2",
Offset: callT2Off,
ReturnType: smartcontract.IntegerType,
},
} }
m.Permissions = make([]manifest.Permission, 2)
m.Permissions[0].Contract.Type = manifest.PermissionHash
m.Permissions[0].Contract.Value = bc.contracts.NEO.Hash
m.Permissions[0].Methods.Add("balanceOf")
m.Permissions[1].Contract.Type = manifest.PermissionHash
m.Permissions[1].Contract.Value = util.Uint160{}
m.Permissions[1].Methods.Add("method")
cs := &state.Contract{ cs := &state.Contract{
Hash: h, Hash: h,
Manifest: *m, Manifest: *m,
@ -626,6 +659,22 @@ func getTestContractState(bc *Blockchain) (*state.Contract, *state.Contract) {
if err != nil { if err != nil {
panic(err) panic(err)
} }
ne.Tokens = []nef.MethodToken{
{
Hash: bc.contracts.NEO.Hash,
Method: "balanceOf",
ParamCount: 1,
HasReturn: true,
CallFlag: callflag.ReadStates,
},
{
Hash: util.Uint160{},
Method: "method",
HasReturn: true,
CallFlag: callflag.ReadStates,
},
}
ne.Checksum = ne.CalculateChecksum()
cs.NEF = *ne cs.NEF = *ne
currScript := []byte{byte(opcode.RET)} currScript := []byte{byte(opcode.RET)}
@ -980,3 +1029,28 @@ func TestRuntimeCheckWitness(t *testing.T) {
}) })
}) })
} }
func TestLoadToken(t *testing.T) {
bc := newTestChain(t)
defer bc.Close()
cs, _ := getTestContractState(bc)
require.NoError(t, bc.contracts.Management.PutContractState(bc.dao, cs))
t.Run("good", func(t *testing.T) {
aer, err := invokeContractMethod(bc, 1_00000000, cs.Hash, "callT0", neoOwner.BytesBE())
require.NoError(t, err)
realBalance, _ := bc.GetGoverningTokenBalance(neoOwner)
checkResult(t, aer, stackitem.Make(realBalance.Int64()+1))
})
t.Run("invalid param count", func(t *testing.T) {
aer, err := invokeContractMethod(bc, 1_00000000, cs.Hash, "callT2")
require.NoError(t, err)
checkFAULTState(t, aer)
})
t.Run("invalid contract", func(t *testing.T) {
aer, err := invokeContractMethod(bc, 1_00000000, cs.Hash, "callT1")
require.NoError(t, err)
checkFAULTState(t, aer)
})
}

View file

@ -83,9 +83,7 @@ func TestContractDeploy(t *testing.T) {
cs1.Hash = state.CreateContractHash(testchain.MultisigScriptHash(), cs1.NEF.Script) cs1.Hash = state.CreateContractHash(testchain.MultisigScriptHash(), cs1.NEF.Script)
manif1, err := json.Marshal(cs1.Manifest) manif1, err := json.Marshal(cs1.Manifest)
require.NoError(t, err) require.NoError(t, err)
nef1, err := nef.NewFile(cs1.NEF.Script) nef1b, err := cs1.NEF.Bytes()
require.NoError(t, err)
nef1b, err := nef1.Bytes()
require.NoError(t, err) require.NoError(t, err)
t.Run("no NEF", func(t *testing.T) { t.Run("no NEF", func(t *testing.T) {

View file

@ -7,6 +7,7 @@ import (
"github.com/nspcc-dev/neo-go/pkg/crypto/hash" "github.com/nspcc-dev/neo-go/pkg/crypto/hash"
"github.com/nspcc-dev/neo-go/pkg/smartcontract/callflag" "github.com/nspcc-dev/neo-go/pkg/smartcontract/callflag"
"github.com/nspcc-dev/neo-go/pkg/smartcontract/nef"
"github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/util"
"github.com/nspcc-dev/neo-go/pkg/vm/opcode" "github.com/nspcc-dev/neo-go/pkg/vm/opcode"
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
@ -42,9 +43,6 @@ type Context struct {
// Caller's contract script hash. // Caller's contract script hash.
callingScriptHash util.Uint160 callingScriptHash util.Uint160
// Set to true when running deployed contracts.
isDeployed bool
// Call flags this context was created with. // Call flags this context was created with.
callFlag callflag.CallFlag callFlag callflag.CallFlag
@ -52,6 +50,8 @@ type Context struct {
ParamCount int ParamCount int
// RetCount specifies number of return values. // RetCount specifies number of return values.
RetCount int RetCount int
// NEF represents NEF file for the current contract.
NEF *nef.File
} }
// CheckReturnState represents possible states of stack after opcode.RET was processed. // CheckReturnState represents possible states of stack after opcode.RET was processed.
@ -144,7 +144,7 @@ func (c *Context) Next() (opcode.Opcode, []byte, error) {
opcode.ENDTRY, opcode.ENDTRY,
opcode.INITSSLOT, opcode.LDSFLD, opcode.STSFLD, opcode.LDARG, opcode.STARG, opcode.LDLOC, opcode.STLOC: opcode.INITSSLOT, opcode.LDSFLD, opcode.STSFLD, opcode.LDARG, opcode.STARG, opcode.LDLOC, opcode.STLOC:
numtoread = 1 numtoread = 1
case opcode.INITSLOT, opcode.TRY: case opcode.INITSLOT, opcode.TRY, opcode.CALLT:
numtoread = 2 numtoread = 2
case opcode.JMPL, opcode.JMPIFL, opcode.JMPIFNOTL, opcode.JMPEQL, opcode.JMPNEL, case opcode.JMPL, opcode.JMPIFL, opcode.JMPIFNOTL, opcode.JMPEQL, opcode.JMPNEL,
opcode.JMPGTL, opcode.JMPGEL, opcode.JMPLTL, opcode.JMPLEL, opcode.JMPGTL, opcode.JMPGEL, opcode.JMPLTL, opcode.JMPLEL,
@ -273,7 +273,7 @@ func (c *Context) String() string {
// IsDeployed returns whether this context contains deployed contract. // IsDeployed returns whether this context contains deployed contract.
func (c *Context) IsDeployed() bool { func (c *Context) IsDeployed() bool {
return c.isDeployed return c.NEF != nil
} }
// getContextScriptHash returns script hash of the invocation stack element // getContextScriptHash returns script hash of the invocation stack element

View file

@ -80,6 +80,9 @@ type VM struct {
// SyscallHandler handles SYSCALL opcode. // SyscallHandler handles SYSCALL opcode.
SyscallHandler func(v *VM, id uint32) error SyscallHandler func(v *VM, id uint32) error
// LoadToken handles CALLT opcode.
LoadToken func(id int32) error
trigger trigger.Type trigger trigger.Type
// Invocations is a script invocation counter. // Invocations is a script invocation counter.
@ -305,7 +308,6 @@ func (v *VM) LoadScriptWithCallingHash(caller util.Uint160, b []byte, hash util.
f callflag.CallFlag, hasReturn bool, paramCount uint16) { f callflag.CallFlag, hasReturn bool, paramCount uint16) {
v.LoadScriptWithFlags(b, f) v.LoadScriptWithFlags(b, f)
ctx := v.Context() ctx := v.Context()
ctx.isDeployed = true
ctx.scriptHash = hash ctx.scriptHash = hash
ctx.callingScriptHash = caller ctx.callingScriptHash = caller
if hasReturn { if hasReturn {
@ -1276,6 +1278,12 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro
v.call(ctx, ptr.Position()) v.call(ctx, ptr.Position())
case opcode.CALLT:
id := int32(binary.LittleEndian.Uint16(parameter))
if err := v.LoadToken(id); err != nil {
panic(err)
}
case opcode.SYSCALL: case opcode.SYSCALL:
interopID := GetInteropID(parameter) interopID := GetInteropID(parameter)
err := v.SyscallHandler(v, interopID) err := v.SyscallHandler(v, interopID)
@ -1510,6 +1518,7 @@ func (v *VM) call(ctx *Context, offset int) {
newCtx.local = nil newCtx.local = nil
newCtx.arguments = nil newCtx.arguments = nil
newCtx.tryStack = NewStack("exception") newCtx.tryStack = NewStack("exception")
newCtx.NEF = ctx.NEF
v.istack.PushVal(newCtx) v.istack.PushVal(newCtx)
v.Jump(newCtx, offset) v.Jump(newCtx, offset)
} }