diff --git a/pkg/core/blockchain.go b/pkg/core/blockchain.go index e4fa6b0f2..3af24d84b 100644 --- a/pkg/core/blockchain.go +++ b/pkg/core/blockchain.go @@ -16,6 +16,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/core/blockchainer" "github.com/nspcc-dev/neo-go/pkg/core/dao" "github.com/nspcc-dev/neo-go/pkg/core/interop" + "github.com/nspcc-dev/neo-go/pkg/core/interop/contract" "github.com/nspcc-dev/neo-go/pkg/core/mempool" "github.com/nspcc-dev/neo-go/pkg/core/mpt" "github.com/nspcc-dev/neo-go/pkg/core/native" @@ -625,6 +626,7 @@ func (bc *Blockchain) storeBlock(block *block.Block, txpool *mempool.Pool) error v := systemInterop.SpawnVM() v.LoadScriptWithFlags(tx.Script, callflag.All) v.SetPriceGetter(bc.getPrice) + v.LoadToken = contract.LoadToken(systemInterop) v.GasLimit = tx.SystemFee err := v.Run() @@ -1635,6 +1637,7 @@ func (bc *Blockchain) GetTestVM(t trigger.Type, tx *transaction.Transaction, b * systemInterop := bc.newInteropContext(t, d, b, tx) vm := systemInterop.SpawnVM() vm.SetPriceGetter(bc.getPrice) + vm.LoadToken = contract.LoadToken(systemInterop) return vm } @@ -1670,6 +1673,7 @@ func (bc *Blockchain) initVerificationVM(ic *interop.Context, hash util.Uint160, } initMD := cs.Manifest.ABI.GetMethod(manifest.MethodInit) v.LoadScriptWithHash(cs.NEF.Script, hash, callflag.ReadStates) + v.Context().NEF = &cs.NEF v.Jump(v.Context(), md.Offset) if cs.ID <= 0 { @@ -1704,6 +1708,7 @@ func (bc *Blockchain) verifyHashAgainstScript(hash util.Uint160, witness *transa vm := interopCtx.SpawnVM() vm.SetPriceGetter(bc.getPrice) + vm.LoadToken = contract.LoadToken(interopCtx) vm.GasLimit = gas if err := bc.initVerificationVM(interopCtx, hash, witness); err != nil { return 0, err diff --git a/pkg/core/interop/contract/call.go b/pkg/core/interop/contract/call.go index d0b155544..6603f337d 100644 --- a/pkg/core/interop/contract/call.go +++ b/pkg/core/interop/contract/call.go @@ -15,6 +15,26 @@ import ( "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" ) +// LoadToken calls method specified by token id. +func LoadToken(ic *interop.Context) func(id int32) error { + return func(id int32) error { + ctx := ic.VM.Context() + tok := ctx.NEF.Tokens[id] + if int(tok.ParamCount) > ctx.Estack().Len() { + return errors.New("stack is too small") + } + args := make([]stackitem.Item, tok.ParamCount) + for i := range args { + args[i] = ic.VM.Estack().Pop().Item() + } + cs, err := ic.GetContract(tok.Hash) + if err != nil { + return fmt.Errorf("contract not found: %w", err) + } + return callInternal(ic, cs, tok.Method, tok.CallFlag, tok.HasReturn, args) + } +} + // Call calls a contract with flags. func Call(ic *interop.Context) error { h := ic.VM.Estack().Pop().Bytes() @@ -24,10 +44,6 @@ func Call(ic *interop.Context) error { return errors.New("call flags out of range") } args := ic.VM.Estack().Pop().Array() - return callInternal(ic, h, method, fs, args) -} - -func callInternal(ic *interop.Context, h []byte, name string, f callflag.CallFlag, args []stackitem.Item) error { u, err := util.Uint160DecodeBytesBE(h) if err != nil { return errors.New("invalid contract hash") @@ -36,10 +52,10 @@ func callInternal(ic *interop.Context, h []byte, name string, f callflag.CallFla if err != nil { return fmt.Errorf("contract not found: %w", err) } - if strings.HasPrefix(name, "_") { + if strings.HasPrefix(method, "_") { return errors.New("invalid method name (starts with '_')") } - md := cs.Manifest.ABI.GetMethod(name) + md := cs.Manifest.ABI.GetMethod(method) if md == nil { return errors.New("method not found") } @@ -47,12 +63,18 @@ func callInternal(ic *interop.Context, h []byte, name string, f callflag.CallFla if !hasReturn { ic.VM.Estack().PushVal(stackitem.Null{}) } + return callInternal(ic, cs, method, fs, hasReturn, args) +} + +func callInternal(ic *interop.Context, cs *state.Contract, name string, f callflag.CallFlag, + hasReturn bool, args []stackitem.Item) error { + md := cs.Manifest.ABI.GetMethod(name) if md.Safe { f &^= callflag.WriteStates } else if ctx := ic.VM.Context(); ctx != nil && ctx.IsDeployed() { curr, err := ic.GetContract(ic.VM.GetCurrentScriptHash()) if err == nil { - if !curr.Manifest.CanCall(u, &cs.Manifest, name) { + if !curr.Manifest.CanCall(cs.Hash, &cs.Manifest, name) { return errors.New("disallowed method call") } } @@ -74,6 +96,7 @@ func callExFromNative(ic *interop.Context, caller util.Uint160, cs *state.Contra ic.VM.Invocations[cs.Hash]++ ic.VM.LoadScriptWithCallingHash(caller, cs.NEF.Script, cs.Hash, ic.VM.Context().GetCallFlags()&f, true, uint16(len(args))) + ic.VM.Context().NEF = &cs.NEF var isNative bool for i := range ic.Natives { if ic.Natives[i].Metadata().Hash.Equals(cs.Hash) { diff --git a/pkg/core/interop_system_test.go b/pkg/core/interop_system_test.go index d4a333f58..b30e80a3f 100644 --- a/pkg/core/interop_system_test.go +++ b/pkg/core/interop_system_test.go @@ -513,6 +513,12 @@ func getTestContractState(bc *Blockchain) (*state.Contract, *state.Contract) { emit.Opcodes(w.BinWriter, opcode.NEWARRAY0, opcode.DUP, opcode.DUP, opcode.APPEND, opcode.NEWMAP) emit.Syscall(w.BinWriter, interopnames.SystemIteratorCreate) emit.Opcodes(w.BinWriter, opcode.RET) + callT0Off := w.Len() + emit.Opcodes(w.BinWriter, opcode.CALLT, 0, 0, opcode.PUSH1, opcode.ADD, opcode.RET) + callT1Off := w.Len() + emit.Opcodes(w.BinWriter, opcode.CALLT, 1, 0, opcode.RET) + callT2Off := w.Len() + emit.Opcodes(w.BinWriter, opcode.CALLT, 0, 0, opcode.RET) script := w.Bytes() h := hash.Hash160(script) @@ -616,7 +622,34 @@ func getTestContractState(bc *Blockchain) (*state.Contract, *state.Contract) { Offset: invalidStackOff, ReturnType: smartcontract.VoidType, }, + { + Name: "callT0", + Offset: callT0Off, + Parameters: []manifest.Parameter{ + manifest.NewParameter("address", smartcontract.Hash160Type), + }, + ReturnType: smartcontract.IntegerType, + }, + { + Name: "callT1", + Offset: callT1Off, + ReturnType: smartcontract.IntegerType, + }, + { + Name: "callT2", + Offset: callT2Off, + ReturnType: smartcontract.IntegerType, + }, } + m.Permissions = make([]manifest.Permission, 2) + m.Permissions[0].Contract.Type = manifest.PermissionHash + m.Permissions[0].Contract.Value = bc.contracts.NEO.Hash + m.Permissions[0].Methods.Add("balanceOf") + + m.Permissions[1].Contract.Type = manifest.PermissionHash + m.Permissions[1].Contract.Value = util.Uint160{} + m.Permissions[1].Methods.Add("method") + cs := &state.Contract{ Hash: h, Manifest: *m, @@ -626,6 +659,22 @@ func getTestContractState(bc *Blockchain) (*state.Contract, *state.Contract) { if err != nil { panic(err) } + ne.Tokens = []nef.MethodToken{ + { + Hash: bc.contracts.NEO.Hash, + Method: "balanceOf", + ParamCount: 1, + HasReturn: true, + CallFlag: callflag.ReadStates, + }, + { + Hash: util.Uint160{}, + Method: "method", + HasReturn: true, + CallFlag: callflag.ReadStates, + }, + } + ne.Checksum = ne.CalculateChecksum() cs.NEF = *ne currScript := []byte{byte(opcode.RET)} @@ -980,3 +1029,28 @@ func TestRuntimeCheckWitness(t *testing.T) { }) }) } + +func TestLoadToken(t *testing.T) { + bc := newTestChain(t) + defer bc.Close() + + cs, _ := getTestContractState(bc) + require.NoError(t, bc.contracts.Management.PutContractState(bc.dao, cs)) + + t.Run("good", func(t *testing.T) { + aer, err := invokeContractMethod(bc, 1_00000000, cs.Hash, "callT0", neoOwner.BytesBE()) + require.NoError(t, err) + realBalance, _ := bc.GetGoverningTokenBalance(neoOwner) + checkResult(t, aer, stackitem.Make(realBalance.Int64()+1)) + }) + t.Run("invalid param count", func(t *testing.T) { + aer, err := invokeContractMethod(bc, 1_00000000, cs.Hash, "callT2") + require.NoError(t, err) + checkFAULTState(t, aer) + }) + t.Run("invalid contract", func(t *testing.T) { + aer, err := invokeContractMethod(bc, 1_00000000, cs.Hash, "callT1") + require.NoError(t, err) + checkFAULTState(t, aer) + }) +} diff --git a/pkg/core/native_management_test.go b/pkg/core/native_management_test.go index 4a208713d..687998e1f 100644 --- a/pkg/core/native_management_test.go +++ b/pkg/core/native_management_test.go @@ -83,9 +83,7 @@ func TestContractDeploy(t *testing.T) { cs1.Hash = state.CreateContractHash(testchain.MultisigScriptHash(), cs1.NEF.Script) manif1, err := json.Marshal(cs1.Manifest) require.NoError(t, err) - nef1, err := nef.NewFile(cs1.NEF.Script) - require.NoError(t, err) - nef1b, err := nef1.Bytes() + nef1b, err := cs1.NEF.Bytes() require.NoError(t, err) t.Run("no NEF", func(t *testing.T) { diff --git a/pkg/vm/context.go b/pkg/vm/context.go index 340a6a5b2..dcbee69ab 100644 --- a/pkg/vm/context.go +++ b/pkg/vm/context.go @@ -7,6 +7,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/crypto/hash" "github.com/nspcc-dev/neo-go/pkg/smartcontract/callflag" + "github.com/nspcc-dev/neo-go/pkg/smartcontract/nef" "github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/vm/opcode" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" @@ -42,9 +43,6 @@ type Context struct { // Caller's contract script hash. callingScriptHash util.Uint160 - // Set to true when running deployed contracts. - isDeployed bool - // Call flags this context was created with. callFlag callflag.CallFlag @@ -52,6 +50,8 @@ type Context struct { ParamCount int // RetCount specifies number of return values. RetCount int + // NEF represents NEF file for the current contract. + NEF *nef.File } // CheckReturnState represents possible states of stack after opcode.RET was processed. @@ -144,7 +144,7 @@ func (c *Context) Next() (opcode.Opcode, []byte, error) { opcode.ENDTRY, opcode.INITSSLOT, opcode.LDSFLD, opcode.STSFLD, opcode.LDARG, opcode.STARG, opcode.LDLOC, opcode.STLOC: numtoread = 1 - case opcode.INITSLOT, opcode.TRY: + case opcode.INITSLOT, opcode.TRY, opcode.CALLT: numtoread = 2 case opcode.JMPL, opcode.JMPIFL, opcode.JMPIFNOTL, opcode.JMPEQL, opcode.JMPNEL, opcode.JMPGTL, opcode.JMPGEL, opcode.JMPLTL, opcode.JMPLEL, @@ -273,7 +273,7 @@ func (c *Context) String() string { // IsDeployed returns whether this context contains deployed contract. func (c *Context) IsDeployed() bool { - return c.isDeployed + return c.NEF != nil } // getContextScriptHash returns script hash of the invocation stack element diff --git a/pkg/vm/vm.go b/pkg/vm/vm.go index 6e92c2298..a1112a7b4 100644 --- a/pkg/vm/vm.go +++ b/pkg/vm/vm.go @@ -80,6 +80,9 @@ type VM struct { // SyscallHandler handles SYSCALL opcode. SyscallHandler func(v *VM, id uint32) error + // LoadToken handles CALLT opcode. + LoadToken func(id int32) error + trigger trigger.Type // Invocations is a script invocation counter. @@ -305,7 +308,6 @@ func (v *VM) LoadScriptWithCallingHash(caller util.Uint160, b []byte, hash util. f callflag.CallFlag, hasReturn bool, paramCount uint16) { v.LoadScriptWithFlags(b, f) ctx := v.Context() - ctx.isDeployed = true ctx.scriptHash = hash ctx.callingScriptHash = caller if hasReturn { @@ -1276,6 +1278,12 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro v.call(ctx, ptr.Position()) + case opcode.CALLT: + id := int32(binary.LittleEndian.Uint16(parameter)) + if err := v.LoadToken(id); err != nil { + panic(err) + } + case opcode.SYSCALL: interopID := GetInteropID(parameter) err := v.SyscallHandler(v, interopID) @@ -1510,6 +1518,7 @@ func (v *VM) call(ctx *Context, offset int) { newCtx.local = nil newCtx.arguments = nil newCtx.tryStack = NewStack("exception") + newCtx.NEF = ctx.NEF v.istack.PushVal(newCtx) v.Jump(newCtx, offset) }