*: support invoking methods by offset

Allow to invoke methods by offset:
1. Every invoked contract must have manifest.
2. Check arguments count on invocation.
3. Change AppCall to a regular syscall.
4. Add test suite for `System.Contract.Call`.
This commit is contained in:
Evgenii Stratonikov 2020-07-23 18:13:02 +03:00
parent e87eba51f9
commit d2ddf7b7cb
14 changed files with 272 additions and 117 deletions

View file

@ -16,7 +16,6 @@ var (
goBuiltins = []string{"len", "append", "panic"} goBuiltins = []string{"len", "append", "panic"}
// Custom builtin utility functions. // Custom builtin utility functions.
customBuiltins = []string{ customBuiltins = []string{
"AppCall",
"FromAddress", "Equals", "FromAddress", "Equals",
"ToBool", "ToByteArray", "ToInteger", "ToBool", "ToByteArray", "ToInteger",
} }

View file

@ -1220,13 +1220,6 @@ func (c *codegen) convertBuiltin(expr *ast.CallExpr) {
typ = stackitem.BooleanT typ = stackitem.BooleanT
} }
c.emitConvert(typ) c.emitConvert(typ)
case "AppCall":
c.emitReverse(len(expr.Args))
buf := c.getByteArray(expr.Args[0])
if buf != nil && len(buf) != 20 {
c.prog.Err = errors.New("invalid script hash")
}
emit.Syscall(c.prog.BinWriter, "System.Contract.Call")
case "Equals": case "Equals":
emit.Opcode(c.prog.BinWriter, opcode.EQUAL) emit.Opcode(c.prog.BinWriter, opcode.EQUAL)
case "FromAddress": case "FromAddress":

View file

@ -74,17 +74,25 @@ func TestAppCall(t *testing.T) {
srcInner := ` srcInner := `
package foo package foo
func Main(a []byte, b []byte) []byte { func Main(a []byte, b []byte) []byte {
panic("Main was called")
}
func Append(a []byte, b []byte) []byte {
return append(a, b...) return append(a, b...)
} }
` `
inner, err := compiler.Compile(strings.NewReader(srcInner)) inner, di, err := compiler.CompileWithDebugInfo(strings.NewReader(srcInner))
require.NoError(t, err)
m, err := di.ConvertToManifest(smartcontract.NoProperties)
require.NoError(t, err) require.NoError(t, err)
ic := interop.NewContext(trigger.Application, nil, dao.NewSimple(storage.NewMemoryStore(), netmode.UnitTestNet), nil, nil, nil, zaptest.NewLogger(t))
require.NoError(t, ic.DAO.PutContractState(&state.Contract{Script: inner}))
ih := hash.Hash160(inner) ih := hash.Hash160(inner)
ic := interop.NewContext(trigger.Application, nil, dao.NewSimple(storage.NewMemoryStore(), netmode.UnitTestNet), nil, nil, nil, zaptest.NewLogger(t))
require.NoError(t, ic.DAO.PutContractState(&state.Contract{
Script: inner,
Manifest: *m,
}))
t.Run("valid script", func(t *testing.T) { t.Run("valid script", func(t *testing.T) {
src := getAppCallScript(fmt.Sprintf("%#v", ih.BytesBE())) src := getAppCallScript(fmt.Sprintf("%#v", ih.BytesBE()))
v := spawnVM(t, ic, src) v := spawnVM(t, ic, src)
@ -102,13 +110,6 @@ func TestAppCall(t *testing.T) {
require.Error(t, v.Run()) require.Error(t, v.Run())
}) })
t.Run("invalid script address", func(t *testing.T) {
src := getAppCallScript("[]byte{1, 2, 3}")
_, err := compiler.Compile(strings.NewReader(src))
require.Error(t, err)
})
t.Run("convert from string constant", func(t *testing.T) { t.Run("convert from string constant", func(t *testing.T) {
src := ` src := `
package foo package foo
@ -117,7 +118,7 @@ func TestAppCall(t *testing.T) {
func Main() []byte { func Main() []byte {
x := []byte{1, 2} x := []byte{1, 2}
y := []byte{3, 4} y := []byte{3, 4}
result := engine.AppCall([]byte(scriptHash), x, y) result := engine.AppCall([]byte(scriptHash), "append", x, y)
return result.([]byte) return result.([]byte)
} }
` `
@ -136,7 +137,7 @@ func TestAppCall(t *testing.T) {
x := []byte{1, 2} x := []byte{1, 2}
y := []byte{3, 4} y := []byte{3, 4}
var addr = []byte(` + fmt.Sprintf("%#v", string(ih.BytesBE())) + `) var addr = []byte(` + fmt.Sprintf("%#v", string(ih.BytesBE())) + `)
result := engine.AppCall(addr, x, y) result := engine.AppCall(addr, "append", x, y)
return result.([]byte) return result.([]byte)
} }
` `
@ -155,7 +156,7 @@ func getAppCallScript(h string) string {
func Main() []byte { func Main() []byte {
x := []byte{1, 2} x := []byte{1, 2}
y := []byte{3, 4} y := []byte{3, 4}
result := engine.AppCall(` + h + `, x, y) result := engine.AppCall(` + h + `, "append", x, y)
return result.([]byte) return result.([]byte)
} }
` `

View file

@ -44,6 +44,9 @@ var syscalls = map[string]map[string]Syscall{
"Next": {"System.Enumerator.Next", false}, "Next": {"System.Enumerator.Next", false},
"Value": {"System.Enumerator.Value", false}, "Value": {"System.Enumerator.Value", false},
}, },
"engine": {
"AppCall": {"System.Contract.Call", false},
},
"iterator": { "iterator": {
"Concat": {"System.Iterator.Concat", false}, "Concat": {"System.Iterator.Concat", false},
"Create": {"System.Iterator.Create", false}, "Create": {"System.Iterator.Create", false},

View file

@ -252,7 +252,7 @@ func TestCreateBasicChain(t *testing.T) {
// Now invoke this contract. // Now invoke this contract.
script = io.NewBufBinWriter() script = io.NewBufBinWriter()
emit.AppCallWithOperationAndArgs(script.BinWriter, hash.Hash160(avm), "Put", "testkey", "testvalue") emit.AppCallWithOperationAndArgs(script.BinWriter, hash.Hash160(avm), "putValue", "testkey", "testvalue")
txInv := transaction.New(testchain.Network(), script.Bytes(), 1*native.GASFactor) txInv := transaction.New(testchain.Network(), script.Bytes(), 1*native.GASFactor)
txInv.Nonce = getNextNonce() txInv.Nonce = getNextNonce()

View file

@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"math" "math"
"math/big" "math/big"
"strings"
"unicode/utf8" "unicode/utf8"
"github.com/nspcc-dev/neo-go/pkg/core/block" "github.com/nspcc-dev/neo-go/pkg/core/block"
@ -492,16 +493,49 @@ func contractCallExInternal(ic *interop.Context, v *vm.VM, h []byte, method stac
if err != nil { if err != nil {
return err return err
} }
name := string(bs)
if strings.HasPrefix(name, "_") {
return errors.New("invalid method name (starts with '_')")
}
md := cs.Manifest.ABI.GetMethod(name)
if md == nil {
return fmt.Errorf("method '%s' not found", name)
}
curr, err := ic.DAO.GetContractState(v.GetCurrentScriptHash()) curr, err := ic.DAO.GetContractState(v.GetCurrentScriptHash())
if err == nil { if err == nil {
if !curr.Manifest.CanCall(&cs.Manifest, string(bs)) { if !curr.Manifest.CanCall(&cs.Manifest, string(bs)) {
return errors.New("disallowed method call") return errors.New("disallowed method call")
} }
} }
arr, ok := args.Value().([]stackitem.Item)
if !ok {
return errors.New("second argument must be an array")
}
if len(arr) != len(md.Parameters) {
return fmt.Errorf("invalid argument count: %d (expected %d)", len(arr), len(md.Parameters))
}
ic.Invocations[u]++ ic.Invocations[u]++
v.LoadScriptWithHash(cs.Script, u, v.Context().GetCallFlags()&f) v.LoadScriptWithHash(cs.Script, u, v.Context().GetCallFlags()&f)
v.Estack().PushVal(args) var isNative bool
v.Estack().PushVal(method) for i := range ic.Natives {
if ic.Natives[i].Metadata().Hash.Equals(u) {
isNative = true
break
}
}
if isNative {
v.Estack().PushVal(args)
v.Estack().PushVal(method)
} else {
for i := len(arr) - 1; i >= 0; i-- {
v.Estack().PushVal(arr[i])
}
// use Jump not Call here because context was loaded in LoadScript above.
v.Jump(v.Context(), md.Offset)
}
return nil return nil
} }

View file

@ -14,6 +14,7 @@ import (
"github.com/nspcc-dev/neo-go/pkg/smartcontract" "github.com/nspcc-dev/neo-go/pkg/smartcontract"
"github.com/nspcc-dev/neo-go/pkg/smartcontract/manifest" "github.com/nspcc-dev/neo-go/pkg/smartcontract/manifest"
"github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/util"
"github.com/nspcc-dev/neo-go/pkg/vm"
"github.com/nspcc-dev/neo-go/pkg/vm/opcode" "github.com/nspcc-dev/neo-go/pkg/vm/opcode"
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -323,6 +324,109 @@ func TestBlockchainGetContractState(t *testing.T) {
}) })
} }
func getTestContractState() *state.Contract {
script := []byte{
byte(opcode.ABORT), // abort if no offset was provided
byte(opcode.ADD), byte(opcode.RET),
byte(opcode.PUSH7), byte(opcode.RET),
}
h := hash.Hash160(script)
m := manifest.NewManifest(h)
m.ABI.Methods = []manifest.Method{
{
Name: "add",
Offset: 1,
Parameters: []manifest.Parameter{
manifest.NewParameter("addend1", smartcontract.IntegerType),
manifest.NewParameter("addend2", smartcontract.IntegerType),
},
ReturnType: smartcontract.IntegerType,
},
{
Name: "ret7",
Offset: 3,
Parameters: []manifest.Parameter{},
ReturnType: smartcontract.IntegerType,
},
}
return &state.Contract{
Script: script,
Manifest: *m,
ID: 42,
}
}
func TestContractCall(t *testing.T) {
v, ic, bc := createVM(t)
defer bc.Close()
cs := getTestContractState()
require.NoError(t, ic.DAO.PutContractState(cs))
currScript := []byte{byte(opcode.NOP)}
initVM := func(v *vm.VM) {
v.Istack().Clear()
v.Estack().Clear()
v.Load(currScript)
v.Estack().PushVal(42) // canary
}
h := cs.Manifest.ABI.Hash
m := manifest.NewManifest(hash.Hash160(currScript))
perm := manifest.NewPermission(manifest.PermissionHash, h)
perm.Methods.Add("add")
m.Permissions = append(m.Permissions, *perm)
require.NoError(t, ic.DAO.PutContractState(&state.Contract{
Script: currScript,
Manifest: *m,
ID: 123,
}))
addArgs := stackitem.NewArray([]stackitem.Item{stackitem.Make(1), stackitem.Make(2)})
t.Run("Good", func(t *testing.T) {
initVM(v)
v.Estack().PushVal(addArgs)
v.Estack().PushVal("add")
v.Estack().PushVal(h.BytesBE())
require.NoError(t, contractCall(ic, v))
require.NoError(t, v.Run())
require.Equal(t, 2, v.Estack().Len())
require.Equal(t, big.NewInt(3), v.Estack().Pop().Value())
require.Equal(t, big.NewInt(42), v.Estack().Pop().Value())
})
t.Run("CallExInvalidFlag", func(t *testing.T) {
initVM(v)
v.Estack().PushVal(byte(0xFF))
v.Estack().PushVal(addArgs)
v.Estack().PushVal("add")
v.Estack().PushVal(h.BytesBE())
require.Error(t, contractCallEx(ic, v))
})
runInvalid := func(args ...interface{}) func(t *testing.T) {
return func(t *testing.T) {
initVM(v)
for i := range args {
v.Estack().PushVal(args[i])
}
require.Error(t, contractCall(ic, v))
}
}
t.Run("Invalid", func(t *testing.T) {
t.Run("Hash", runInvalid(addArgs, "add", h.BytesBE()[1:]))
t.Run("MissingHash", runInvalid(addArgs, "add", util.Uint160{}.BytesBE()))
t.Run("Method", runInvalid(addArgs, stackitem.NewInterop("add"), h.BytesBE()))
t.Run("MissingMethod", runInvalid(addArgs, "sub", h.BytesBE()))
t.Run("DisallowedMethod", runInvalid(stackitem.NewArray(nil), "ret7", h.BytesBE()))
t.Run("Arguments", runInvalid(1, "add", h.BytesBE()))
t.Run("NotEnoughArguments", runInvalid(
stackitem.NewArray([]stackitem.Item{stackitem.Make(1)}), "add", h.BytesBE()))
})
}
func TestContractCreate(t *testing.T) { func TestContractCreate(t *testing.T) {
v, cs, ic, bc := createVMAndContractState(t) v, cs, ic, bc := createVMAndContractState(t)
v.GasLimit = -1 v.GasLimit = -1

View file

@ -95,7 +95,10 @@ func TestNativeContract_Invoke(t *testing.T) {
tn := newTestNative() tn := newTestNative()
chain.registerNative(tn) chain.registerNative(tn)
err := chain.dao.PutContractState(&state.Contract{Script: tn.meta.Script}) err := chain.dao.PutContractState(&state.Contract{
Script: tn.meta.Script,
Manifest: tn.meta.Manifest,
})
require.NoError(t, err) require.NoError(t, err)
w := io.NewBufBinWriter() w := io.NewBufBinWriter()

View file

@ -13,6 +13,6 @@ package engine
// dynamic calls in Neo (contracts should have a special property declared // dynamic calls in Neo (contracts should have a special property declared
// and paid for to be able to use dynamic calls). This function uses // and paid for to be able to use dynamic calls). This function uses
// `System.Contract.Call` syscall. // `System.Contract.Call` syscall.
func AppCall(scriptHash []byte, args ...interface{}) interface{} { func AppCall(scriptHash []byte, method string, args ...interface{}) interface{} {
return nil return nil
} }

View file

@ -51,8 +51,8 @@ type rpcTestCase struct {
check func(t *testing.T, e *executor, result interface{}) check func(t *testing.T, e *executor, result interface{})
} }
const testContractHash = "402da558b87b5e54b59dc242c788bb4dd4cd906c" const testContractHash = "6e2d823c81589871590653a100c7e9bdf9c94344"
const deploymentTxHash = "2afd69cc80ebe900a060450e8628b57063f3ec93ca5fc7f94582be4a4f3a041f" const deploymentTxHash = "3b434127495a6dd0e786a2e0f04696009cd6e6e5f9b930f0e79356638532096c"
var rpcTestCases = map[string][]rpcTestCase{ var rpcTestCases = map[string][]rpcTestCase{
"getapplicationlog": { "getapplicationlog": {

View file

@ -11,83 +11,91 @@ const (
) )
func Main(operation string, args []interface{}) interface{} { func Main(operation string, args []interface{}) interface{} {
runtime.Notify("contract call", operation, args) panic("invoking via Main is no longer supported") // catch possible bugs
switch operation { }
case "Put":
ctx := storage.GetContext() func Init() bool {
storage.Put(ctx, args[0].([]byte), args[1].([]byte)) ctx := storage.GetContext()
return true h := runtime.GetExecutingScriptHash()
case "totalSupply": amount := totalSupply
return totalSupply storage.Put(ctx, h, amount)
case "decimals": runtime.Notify("transfer", []byte{}, h, amount)
return decimals return true
case "name": }
return "Rubl"
case "symbol": func Transfer(from, to []byte, amount int) bool {
return "RUB" ctx := storage.GetContext()
case "balanceOf": if len(from) != 20 {
ctx := storage.GetContext() runtime.Log("invalid 'from' address")
addr := args[0].([]byte) return false
if len(addr) != 20 { }
runtime.Log("invalid address") if len(to) != 20 {
return false runtime.Log("invalid 'to' address")
} return false
var amount int }
val := storage.Get(ctx, addr) if amount < 0 {
if val != nil { runtime.Log("invalid amount")
amount = val.(int) return false
} }
runtime.Notify("balanceOf", addr, amount)
return amount var fromBalance int
case "transfer": val := storage.Get(ctx, from)
ctx := storage.GetContext() if val != nil {
from := args[0].([]byte) fromBalance = val.(int)
if len(from) != 20 { }
runtime.Log("invalid 'from' address") if fromBalance < amount {
return false runtime.Log("insufficient funds")
} return false
to := args[1].([]byte) }
if len(to) != 20 { fromBalance -= amount
runtime.Log("invalid 'to' address") storage.Put(ctx, from, fromBalance)
return false
} var toBalance int
amount := args[2].(int) val = storage.Get(ctx, to)
if amount < 0 { if val != nil {
runtime.Log("invalid amount") toBalance = val.(int)
return false }
} toBalance += amount
storage.Put(ctx, to, toBalance)
var fromBalance int
val := storage.Get(ctx, from) runtime.Notify("transfer", from, to, amount)
if val != nil {
fromBalance = val.(int) return true
} }
if fromBalance < amount {
runtime.Log("insufficient funds") func BalanceOf(addr []byte) int {
return false ctx := storage.GetContext()
} if len(addr) != 20 {
fromBalance -= amount runtime.Log("invalid address")
storage.Put(ctx, from, fromBalance) return 0
}
var toBalance int var amount int
val = storage.Get(ctx, to) val := storage.Get(ctx, addr)
if val != nil { if val != nil {
toBalance = val.(int) amount = val.(int)
} }
toBalance += amount runtime.Notify("balanceOf", addr, amount)
storage.Put(ctx, to, toBalance) return amount
}
runtime.Notify("transfer", from, to, amount)
func Name() string {
return true return "Rubl"
case "init": }
ctx := storage.GetContext()
h := runtime.GetExecutingScriptHash() func Symbol() string {
amount := totalSupply return "RUB"
storage.Put(ctx, h, amount) }
runtime.Notify("transfer", []byte{}, h, amount)
return true func Decimals() int {
default: return decimals
panic("invalid operation") }
}
func TotalSupply() int {
return totalSupply
}
func PutValue(key []byte, value []byte) bool {
ctx := storage.GetContext()
storage.Put(ctx, key, value)
return true
} }

Binary file not shown.

View file

@ -68,6 +68,16 @@ func DefaultManifest(h util.Uint160) *Manifest {
return m return m
} }
// GetMethod returns methods with the specified name.
func (a *ABI) GetMethod(name string) *Method {
for i := range a.Methods {
if a.Methods[i].Name == name {
return &a.Methods[i]
}
}
return nil
}
// CanCall returns true is current contract is allowed to call // CanCall returns true is current contract is allowed to call
// method of another contract. // method of another contract.
func (m *Manifest) CanCall(toCall *Manifest, method string) bool { func (m *Manifest) CanCall(toCall *Manifest, method string) bool {

View file

@ -1232,7 +1232,7 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro
} }
if cond { if cond {
v.jump(ctx, offset) v.Jump(ctx, offset)
} }
case opcode.CALL, opcode.CALLL: case opcode.CALL, opcode.CALLL:
@ -1243,7 +1243,7 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro
v.istack.PushVal(newCtx) v.istack.PushVal(newCtx)
offset := v.getJumpOffset(newCtx, parameter) offset := v.getJumpOffset(newCtx, parameter)
v.jump(newCtx, offset) v.Jump(newCtx, offset)
case opcode.CALLA: case opcode.CALLA:
ptr := v.estack.Pop().Item().(*stackitem.Pointer) ptr := v.estack.Pop().Item().(*stackitem.Pointer)
@ -1255,7 +1255,7 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro
newCtx.local = nil newCtx.local = nil
newCtx.arguments = nil newCtx.arguments = nil
v.istack.PushVal(newCtx) v.istack.PushVal(newCtx)
v.jump(newCtx, ptr.Position()) v.Jump(newCtx, ptr.Position())
case opcode.SYSCALL: case opcode.SYSCALL:
interopID := GetInteropID(parameter) interopID := GetInteropID(parameter)
@ -1404,7 +1404,7 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro
} else { } else {
ctx.tryStack.Pop() ctx.tryStack.Pop()
} }
v.jump(ctx, eOffset) v.Jump(ctx, eOffset)
case opcode.ENDFINALLY: case opcode.ENDFINALLY:
if v.uncaughtException != nil { if v.uncaughtException != nil {
@ -1412,7 +1412,7 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro
return return
} }
eCtx := ctx.tryStack.Pop().Value().(*exceptionHandlingContext) eCtx := ctx.tryStack.Pop().Value().(*exceptionHandlingContext)
v.jump(ctx, eCtx.EndOffset) v.Jump(ctx, eCtx.EndOffset)
default: default:
panic(fmt.Sprintf("unknown opcode %s", op.String())) panic(fmt.Sprintf("unknown opcode %s", op.String()))
@ -1468,8 +1468,8 @@ func (v *VM) throw(item stackitem.Item) {
v.handleException() v.handleException()
} }
// jump performs jump to the offset. // Jump performs jump to the offset.
func (v *VM) jump(ctx *Context, offset int) { func (v *VM) Jump(ctx *Context, offset int) {
ctx.nextip = offset ctx.nextip = offset
} }
@ -1526,10 +1526,10 @@ func (v *VM) handleException() {
ectx.State = eCatch ectx.State = eCatch
v.estack.PushVal(v.uncaughtException) v.estack.PushVal(v.uncaughtException)
v.uncaughtException = nil v.uncaughtException = nil
v.jump(ictx, ectx.CatchOffset) v.Jump(ictx, ectx.CatchOffset)
} else { } else {
ectx.State = eFinally ectx.State = eFinally
v.jump(ictx, ectx.FinallyOffset) v.Jump(ictx, ectx.FinallyOffset)
} }
return return
} }