diff --git a/pkg/compiler/interop_test.go b/pkg/compiler/interop_test.go index 4c2ef69ac..39435c243 100644 --- a/pkg/compiler/interop_test.go +++ b/pkg/compiler/interop_test.go @@ -63,11 +63,11 @@ func TestAppCall(t *testing.T) { require.NoError(t, err) ih := hash.Hash160(inner) - getScript := func(u util.Uint160) []byte { + getScript := func(u util.Uint160) ([]byte, bool) { if u.Equals(ih) { - return inner + return inner, true } - return nil + return nil, false } t.Run("valid script", func(t *testing.T) { diff --git a/pkg/core/blockchain.go b/pkg/core/blockchain.go index 0bf1926cb..d3c9fc710 100644 --- a/pkg/core/blockchain.go +++ b/pkg/core/blockchain.go @@ -1929,12 +1929,13 @@ func (bc *Blockchain) GetScriptHashesForVerifying(t *transaction.Transaction) ([ // up for current blockchain. func (bc *Blockchain) spawnVMWithInterops(interopCtx *interopContext) *vm.VM { vm := vm.New() - vm.SetScriptGetter(func(hash util.Uint160) []byte { + vm.SetScriptGetter(func(hash util.Uint160) ([]byte, bool) { cs, err := interopCtx.dao.GetContractState(hash) if err != nil { - return nil + return nil, false } - return cs.Script + hasDynamicInvoke := (cs.Properties & smartcontract.HasDynamicInvoke) != 0 + return cs.Script, hasDynamicInvoke }) vm.RegisterInteropGetter(interopCtx.getSystemInterop) vm.RegisterInteropGetter(interopCtx.getNeoInterop) diff --git a/pkg/vm/context.go b/pkg/vm/context.go index 43ea9c6ae..9f6d19ba0 100644 --- a/pkg/vm/context.go +++ b/pkg/vm/context.go @@ -35,6 +35,9 @@ type Context struct { // Script hash of the prog. scriptHash util.Uint160 + + // Whether it's allowed to make dynamic calls from this context. + hasDynamicInvoke bool } var errNoInstParam = errors.New("failed to read instruction parameter") diff --git a/pkg/vm/json_test.go b/pkg/vm/json_test.go index a92a103a7..c9249e0b3 100644 --- a/pkg/vm/json_test.go +++ b/pkg/vm/json_test.go @@ -174,7 +174,7 @@ func testFile(t *testing.T, filename string) { }) } -func getScript(scripts []map[string]vmUTScript) func(util.Uint160) []byte { +func getScript(scripts []map[string]vmUTScript) func(util.Uint160) ([]byte, bool) { store := make(map[util.Uint160][]byte) for i := range scripts { for _, v := range scripts[i] { @@ -182,7 +182,7 @@ func getScript(scripts []map[string]vmUTScript) func(util.Uint160) []byte { } } - return func(a util.Uint160) []byte { return store[a] } + return func(a util.Uint160) ([]byte, bool) { return store[a], true } } func compareItems(t *testing.T, a, b StackItem) { diff --git a/pkg/vm/vm.go b/pkg/vm/vm.go index 4ec5b603f..411cf523d 100644 --- a/pkg/vm/vm.go +++ b/pkg/vm/vm.go @@ -68,7 +68,7 @@ type VM struct { getPrice func(*VM, opcode.Opcode, []byte) util.Fixed8 // callback to get scripts. - getScript func(util.Uint160) []byte + getScript func(util.Uint160) ([]byte, bool) istack *Stack // invocation stack. estack *Stack // execution stack. @@ -266,9 +266,11 @@ func (v *VM) LoadScript(b []byte) { // loadScriptWithHash if similar to the LoadScript method, but it also loads // given script hash directly into the Context to avoid its recalculations. It's // up to user of this function to make sure the script and hash match each other. -func (v *VM) loadScriptWithHash(b []byte, hash util.Uint160) { +func (v *VM) loadScriptWithHash(b []byte, hash util.Uint160, hasDynamicInvoke bool) { v.LoadScript(b) - v.istack.Top().Value().(*Context).scriptHash = hash + ctx := v.Context() + ctx.scriptHash = hash + ctx.hasDynamicInvoke = hasDynamicInvoke } // Context returns the current executed context. Nil if there is no context, @@ -472,7 +474,7 @@ func (v *VM) SetCheckedHash(h []byte) { } // SetScriptGetter sets the script getter for CALL instructions. -func (v *VM) SetScriptGetter(gs func(util.Uint160) []byte) { +func (v *VM) SetScriptGetter(gs func(util.Uint160) ([]byte, bool)) { v.getScript = gs } @@ -518,6 +520,27 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro } } + switch op { + case opcode.APPCALL, opcode.TAILCALL: + isZero := true + for i := range parameter { + if parameter[i] != 0 { + isZero = false + break + } + } + if !isZero { + break + } + + parameter = v.estack.Pop().Bytes() + fallthrough + case opcode.CALLED, opcode.CALLEDT: + if !ctx.hasDynamicInvoke { + panic("contract is not allowed to make dynamic invocations") + } + } + if op >= opcode.PUSHBYTES1 && op <= opcode.PUSHBYTES75 { v.estack.PushVal(parameter) return @@ -1150,7 +1173,7 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro panic(err) } - script := v.getScript(hash) + script, hasDynamicInvoke := v.getScript(hash) if script == nil { panic("could not find script") } @@ -1159,7 +1182,7 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro _ = v.istack.Pop() } - v.loadScriptWithHash(script, hash) + v.loadScriptWithHash(script, hash, hasDynamicInvoke) case opcode.RET: oldCtx := v.istack.Pop().Value().(*Context) @@ -1354,12 +1377,13 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro if err != nil { panic(err) } - script := v.getScript(hash) + script, hasDynamicInvoke := v.getScript(hash) if script == nil { panic(fmt.Sprintf("could not find script %s", hash)) } newCtx = NewContext(script) newCtx.scriptHash = hash + newCtx.hasDynamicInvoke = hasDynamicInvoke } newCtx.rvcount = rvcount newCtx.estack = NewStack("evaluation") diff --git a/pkg/vm/vm_test.go b/pkg/vm/vm_test.go index c8488d189..6978cec08 100644 --- a/pkg/vm/vm_test.go +++ b/pkg/vm/vm_test.go @@ -1670,16 +1670,16 @@ func TestSIGNByteArray(t *testing.T) { func TestAppCall(t *testing.T) { prog := []byte{byte(opcode.APPCALL)} - hash := util.Uint160{} + hash := util.Uint160{1, 2} prog = append(prog, hash.BytesBE()...) prog = append(prog, byte(opcode.RET)) vm := load(prog) - vm.SetScriptGetter(func(in util.Uint160) []byte { + vm.SetScriptGetter(func(in util.Uint160) ([]byte, bool) { if in.Equals(hash) { - return makeProgram(opcode.DEPTH) + return makeProgram(opcode.DEPTH), true } - return nil + return nil, false }) vm.estack.PushVal(2) @@ -1688,6 +1688,49 @@ func TestAppCall(t *testing.T) { assert.Equal(t, int64(1), elem.BigInt().Int64()) } +func TestAppCallDynamicBad(t *testing.T) { + prog := []byte{byte(opcode.APPCALL)} + hash := util.Uint160{} + prog = append(prog, hash.BytesBE()...) + prog = append(prog, byte(opcode.RET)) + + vm := load(prog) + vm.SetScriptGetter(func(in util.Uint160) ([]byte, bool) { + if in.Equals(hash) { + return makeProgram(opcode.DEPTH), true + } + return nil, false + }) + vm.estack.PushVal(2) + vm.estack.PushVal(hash.BytesBE()) + + checkVMFailed(t, vm) +} + +func TestAppCallDynamicGood(t *testing.T) { + prog := []byte{byte(opcode.APPCALL)} + zeroHash := util.Uint160{} + hash := util.Uint160{1, 2, 3} + prog = append(prog, zeroHash.BytesBE()...) + prog = append(prog, byte(opcode.RET)) + + vm := load(prog) + vm.SetScriptGetter(func(in util.Uint160) ([]byte, bool) { + if in.Equals(hash) { + return makeProgram(opcode.DEPTH), true + } + return nil, false + }) + vm.estack.PushVal(42) + vm.estack.PushVal(42) + vm.estack.PushVal(hash.BytesBE()) + vm.Context().hasDynamicInvoke = true + + runVM(t, vm) + elem := vm.estack.Pop() // depth should be 2 + assert.Equal(t, int64(2), elem.BigInt().Int64()) +} + func TestSimpleCall(t *testing.T) { progStr := "52c56b525a7c616516006c766b00527ac46203006c766b00c3616c756653c56b6c766b00527ac46c766b51527ac46203006c766b00c36c766b51c393616c7566" result := 12