diff --git a/docs/compiler.md b/docs/compiler.md index 56b95a90d..79069be64 100644 --- a/docs/compiler.md +++ b/docs/compiler.md @@ -19,11 +19,11 @@ a dialect of Go rather than a complete port of the language: * goroutines, channels and garbage collection are not supported and will never be because emulating that aspects of Go runtime on top of Neo VM is close to impossible - * even though `panic()` is supported, `recover()` is not, `panic` shuts the - VM down - * lambdas are not supported (#939) - * it's not possible to rename imported interop packages, they won't work this - way (#397, #913) + * `defer` and `recover` are supported except for cases where panic occurs in + `return` statement, because this complicates implementation and imposes runtime + overhead for all contracts. This can easily be mitigated by first storing values + in variables and returning the result. + * lambdas are supported, but closures are not. ## VM API (interop layer) Compiler translates interop function calls into NEO VM syscalls or (for custom diff --git a/pkg/compiler/analysis.go b/pkg/compiler/analysis.go index 62ca08ce6..730a9bd89 100644 --- a/pkg/compiler/analysis.go +++ b/pkg/compiler/analysis.go @@ -14,7 +14,7 @@ import ( var ( // Go language builtin functions. - goBuiltins = []string{"len", "append", "panic", "make", "copy"} + goBuiltins = []string{"len", "append", "panic", "make", "copy", "recover"} // Custom builtin utility functions. customBuiltins = []string{ "FromAddress", "Equals", @@ -40,23 +40,30 @@ func (c *codegen) getIdentName(pkg string, name string) string { // and returns number of variables initialized and // true if any init functions were encountered. func (c *codegen) traverseGlobals() (int, bool) { + var hasDefer bool var n int var hasInit bool c.ForEachFile(func(f *ast.File, _ *types.Package) { n += countGlobals(f) - if !hasInit { + if !hasInit || !hasDefer { ast.Inspect(f, func(node ast.Node) bool { - n, ok := node.(*ast.FuncDecl) - if ok { + switch n := node.(type) { + case *ast.FuncDecl: if isInitFunc(n) { hasInit = true } + return !hasDefer + case *ast.DeferStmt: + hasDefer = true return false } return true }) } }) + if hasDefer { + n++ + } if n != 0 || hasInit { if n > 255 { c.prog.BinWriter.Err = errors.New("too many global variables") @@ -83,6 +90,11 @@ func (c *codegen) traverseGlobals() (int, bool) { // encountered after will be recognized as globals. c.scope = nil }) + // store auxiliary variables after all others. + if hasDefer { + c.exceptionIndex = len(c.globals) + c.globals[""] = c.exceptionIndex + } } return n, hasInit } @@ -92,7 +104,7 @@ func (c *codegen) traverseGlobals() (int, bool) { func countGlobals(f ast.Node) (i int) { ast.Inspect(f, func(node ast.Node) bool { switch n := node.(type) { - // Skip all function declarations. + // Skip all function declarations if we have already encountered `defer`. case *ast.FuncDecl: return false // After skipping all funcDecls we are sure that each value spec diff --git a/pkg/compiler/codegen.go b/pkg/compiler/codegen.go index 4b7c3f2d2..a5fe85d1c 100644 --- a/pkg/compiler/codegen.go +++ b/pkg/compiler/codegen.go @@ -12,7 +12,6 @@ import ( "sort" "strings" - "github.com/nspcc-dev/neo-go/pkg/core/interop/interopnames" "github.com/nspcc-dev/neo-go/pkg/encoding/address" "github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/vm" @@ -78,6 +77,9 @@ type codegen struct { // packages contains packages in the order they were loaded. packages []string + // exceptionIndex is the index of static slot where exception is stored. + exceptionIndex int + // documents contains paths to all files used by the program. documents []string // docIndex maps file path to an index in documents array. @@ -217,6 +219,11 @@ func getBaseOpcode(t varType) (opcode.Opcode, opcode.Opcode) { // emitLoadVar loads specified variable to the evaluation stack. func (c *codegen) emitLoadVar(pkg string, name string) { t, i := c.getVarIndex(pkg, name) + c.emitLoadByIndex(t, i) +} + +// emitLoadByIndex loads specified variable type with index i. +func (c *codegen) emitLoadByIndex(t varType, i int) { base, _ := getBaseOpcode(t) if i < 7 { emit.Opcode(c.prog.BinWriter, base+opcode.Opcode(i)) @@ -232,6 +239,11 @@ func (c *codegen) emitStoreVar(pkg string, name string) { return } t, i := c.getVarIndex(pkg, name) + c.emitStoreByIndex(t, i) +} + +// emitLoadByIndex stores top value in the specified variable type with index i. +func (c *codegen) emitStoreByIndex(t varType, i int) { _, base := getBaseOpcode(t) if i < 7 { emit.Opcode(c.prog.BinWriter, base+opcode.Opcode(i)) @@ -553,6 +565,8 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor { } } + c.processDefers() + c.saveSequencePoint(n) emit.Opcode(c.prog.BinWriter, opcode.RET) return nil @@ -565,6 +579,9 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor { lElse := c.newLabel() lElseEnd := c.newLabel() + if n.Init != nil { + ast.Walk(c, n.Init) + } if n.Cond != nil { c.emitBoolExpr(n.Cond, true, false, lElse) } @@ -708,6 +725,7 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor { numArgs = len(n.Args) isBuiltin bool isFunc bool + isLiteral bool ) switch fun := n.Fun.(type) { @@ -746,6 +764,8 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor { ast.Walk(c, n.Args[0]) c.emitConvert(stackitem.BufferT) return nil + case *ast.FuncLit: + isLiteral = true } c.saveSequencePoint(n) @@ -798,6 +818,9 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor { c.emitLoadVar("", name) emit.Opcode(c.prog.BinWriter, opcode.CALLA) } + case isLiteral: + ast.Walk(c, n.Fun) + emit.Opcode(c.prog.BinWriter, opcode.CALLA) case isSyscall(f): c.convertSyscall(n, f.pkg.Name(), f.name) default: @@ -820,6 +843,20 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor { return nil + case *ast.DeferStmt: + catch := c.newLabel() + finally := c.newLabel() + param := make([]byte, 8) + binary.LittleEndian.PutUint16(param[0:], catch) + binary.LittleEndian.PutUint16(param[4:], finally) + emit.Instruction(c.prog.BinWriter, opcode.TRYL, param) + c.scope.deferStack = append(c.scope.deferStack, deferInfo{ + catchLabel: catch, + finallyLabel: finally, + expr: n.Call, + }) + return nil + case *ast.SelectorExpr: typ := c.typeOf(n.X) if typ == nil { @@ -1081,6 +1118,51 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor { return c } +// processDefers emits code for `defer` statements. +// TRY-related opcodes handle exception as follows: +// 1. CATCH block is executed only if exception has occured. +// 2. FINALLY block is always executed, but after catch block. +// Go `defer` statements are a bit different: +// 1. `defer` is always executed irregardless of whether an exception has occured. +// 2. `recover` can or can not handle a possible exception. +// Thus we use the following approach: +// 1. Throwed exception is saved in a static field X, static fields Y and is set to true. +// 2. CATCH and FINALLY blocks are the same, and both contain the same CALLs. +// 3. In CATCH block we set Y to true and emit default return values if it is the last defer. +// 4. Execute FINALLY block only if Y is false. +func (c *codegen) processDefers() { + for i := len(c.scope.deferStack) - 1; i >= 0; i-- { + stmt := c.scope.deferStack[i] + after := c.newLabel() + emit.Jmp(c.prog.BinWriter, opcode.ENDTRYL, after) + + c.setLabel(stmt.catchLabel) + c.emitStoreByIndex(varGlobal, c.exceptionIndex) + emit.Int(c.prog.BinWriter, 1) + c.emitStoreByIndex(varLocal, c.scope.finallyProcessedIndex) + ast.Walk(c, stmt.expr) + if i == 0 { + // After panic, default values must be returns, except for named returns, + // which we don't support here for now. + for i := len(c.scope.decl.Type.Results.List) - 1; i >= 0; i-- { + c.emitDefault(c.typeOf(c.scope.decl.Type.Results.List[i].Type)) + } + } + emit.Jmp(c.prog.BinWriter, opcode.ENDTRYL, after) + + c.setLabel(stmt.finallyLabel) + before := c.newLabel() + c.emitLoadByIndex(varLocal, c.scope.finallyProcessedIndex) + emit.Jmp(c.prog.BinWriter, opcode.JMPIFL, before) + ast.Walk(c, stmt.expr) + c.setLabel(before) + emit.Int(c.prog.BinWriter, 0) + c.emitStoreByIndex(varLocal, c.scope.finallyProcessedIndex) + emit.Opcode(c.prog.BinWriter, opcode.ENDFINALLY) + c.setLabel(after) + } +} + func (c *codegen) rangeLoadKey() { emit.Int(c.prog.BinWriter, 2) emit.Opcode(c.prog.BinWriter, opcode.PICK) // load keys @@ -1428,17 +1510,11 @@ func (c *codegen) convertBuiltin(expr *ast.CallExpr) { } } case "panic": - arg := expr.Args[0] - if isExprNil(arg) { - emit.Opcode(c.prog.BinWriter, opcode.DROP) - emit.Opcode(c.prog.BinWriter, opcode.THROW) - } else if isString(c.typeInfo.Types[arg].Type) { - ast.Walk(c, arg) - emit.Syscall(c.prog.BinWriter, interopnames.SystemRuntimeLog) - emit.Opcode(c.prog.BinWriter, opcode.THROW) - } else { - c.prog.Err = errors.New("panic should have string or nil argument") - } + emit.Opcode(c.prog.BinWriter, opcode.THROW) + case "recover": + c.emitLoadByIndex(varGlobal, c.exceptionIndex) + emit.Opcode(c.prog.BinWriter, opcode.PUSHNULL) + c.emitStoreByIndex(varGlobal, c.exceptionIndex) case "ToInteger", "ToByteArray", "ToBool": typ := stackitem.IntegerT switch name { @@ -1482,8 +1558,6 @@ func transformArgs(fun ast.Expr, args []ast.Expr) []ast.Expr { } case *ast.Ident: switch f.Name { - case "panic": - return args[1:] case "make", "copy": return nil } @@ -1778,27 +1852,32 @@ func (c *codegen) writeJumps(b []byte) ([]byte, error) { case opcode.JMP, opcode.JMPIFNOT, opcode.JMPIF, opcode.CALL, opcode.JMPEQ, opcode.JMPNE, opcode.JMPGT, opcode.JMPGE, opcode.JMPLE, opcode.JMPLT: + case opcode.TRYL: + nextIP := ctx.NextIP() + catchArg := b[nextIP-8:] + _, err := c.replaceLabelWithOffset(ctx.IP(), catchArg) + if err != nil { + return nil, err + } + finallyArg := b[nextIP-4:] + _, err = c.replaceLabelWithOffset(ctx.IP(), finallyArg) + if err != nil { + return nil, err + } case opcode.JMPL, opcode.JMPIFL, opcode.JMPIFNOTL, opcode.JMPEQL, opcode.JMPNEL, opcode.JMPGTL, opcode.JMPGEL, opcode.JMPLEL, opcode.JMPLTL, - opcode.CALLL, opcode.PUSHA: + opcode.CALLL, opcode.PUSHA, opcode.ENDTRYL: // we can't use arg returned by ctx.Next() because it is copied nextIP := ctx.NextIP() arg := b[nextIP-4:] - - index := binary.LittleEndian.Uint16(arg) - if int(index) > len(c.l) { - return nil, fmt.Errorf("unexpected label number: %d (max %d)", index, len(c.l)) - } - offset := c.l[index] - nextIP + 5 - if offset > math.MaxInt32 || offset < math.MinInt32 { - return nil, fmt.Errorf("label offset is too big at the instruction %d: %d (max %d, min %d)", - nextIP-5, offset, math.MaxInt32, math.MinInt32) + offset, err := c.replaceLabelWithOffset(ctx.IP(), arg) + if err != nil { + return nil, err } if op != opcode.PUSHA && math.MinInt8 <= offset && offset <= math.MaxInt8 { offsets = append(offsets, ctx.IP()) } - binary.LittleEndian.PutUint32(arg, uint32(offset)) } } // Correct function ip range. @@ -1820,6 +1899,20 @@ func (c *codegen) writeJumps(b []byte) ([]byte, error) { return shortenJumps(b, offsets), nil } +func (c *codegen) replaceLabelWithOffset(ip int, arg []byte) (int, error) { + index := binary.LittleEndian.Uint16(arg) + if int(index) > len(c.l) { + return 0, fmt.Errorf("unexpected label number: %d (max %d)", index, len(c.l)) + } + offset := c.l[index] - ip + if offset > math.MaxInt32 || offset < math.MinInt32 { + return 0, fmt.Errorf("label offset is too big at the instruction %d: %d (max %d, min %d)", + ip, offset, math.MaxInt32, math.MinInt32) + } + binary.LittleEndian.PutUint32(arg, uint32(offset)) + return offset, nil +} + // longToShortRemoveCount is a difference between short and long instruction sizes in bytes. const longToShortRemoveCount = 3 @@ -1844,18 +1937,34 @@ func shortenJumps(b []byte, offsets []int) []byte { switch op { case opcode.JMP, opcode.JMPIFNOT, opcode.JMPIF, opcode.CALL, opcode.JMPEQ, opcode.JMPNE, - opcode.JMPGT, opcode.JMPGE, opcode.JMPLE, opcode.JMPLT: + opcode.JMPGT, opcode.JMPGE, opcode.JMPLE, opcode.JMPLT, opcode.ENDTRY: offset := int(int8(b[nextIP-1])) offset += calcOffsetCorrection(ip, ip+offset, offsets) b[nextIP-1] = byte(offset) + case opcode.TRY: + catchOffset := int(int8(b[nextIP-2])) + catchOffset += calcOffsetCorrection(ip, ip+catchOffset, offsets) + b[nextIP-1] = byte(catchOffset) + finallyOffset := int(int8(b[nextIP-1])) + finallyOffset += calcOffsetCorrection(ip, ip+finallyOffset, offsets) + b[nextIP-1] = byte(finallyOffset) case opcode.JMPL, opcode.JMPIFL, opcode.JMPIFNOTL, opcode.JMPEQL, opcode.JMPNEL, opcode.JMPGTL, opcode.JMPGEL, opcode.JMPLEL, opcode.JMPLTL, - opcode.CALLL, opcode.PUSHA: + opcode.CALLL, opcode.PUSHA, opcode.ENDTRYL: arg := b[nextIP-4:] offset := int(int32(binary.LittleEndian.Uint32(arg))) offset += calcOffsetCorrection(ip, ip+offset, offsets) binary.LittleEndian.PutUint32(arg, uint32(offset)) + case opcode.TRYL: + arg := b[nextIP-8:] + catchOffset := int(int32(binary.LittleEndian.Uint32(arg))) + catchOffset += calcOffsetCorrection(ip, ip+catchOffset, offsets) + binary.LittleEndian.PutUint32(arg, uint32(catchOffset)) + arg = b[nextIP-4:] + finallyOffset := int(int32(binary.LittleEndian.Uint32(arg))) + finallyOffset += calcOffsetCorrection(ip, ip+finallyOffset, offsets) + binary.LittleEndian.PutUint32(arg, uint32(finallyOffset)) } } @@ -1939,6 +2048,8 @@ func toShortForm(op opcode.Opcode) opcode.Opcode { return opcode.JMPLT case opcode.CALLL: return opcode.CALL + case opcode.ENDTRYL: + return opcode.ENDTRY default: panic(fmt.Errorf("invalid opcode: %s", op)) } diff --git a/pkg/compiler/defer_test.go b/pkg/compiler/defer_test.go new file mode 100644 index 000000000..4a2f8f3b0 --- /dev/null +++ b/pkg/compiler/defer_test.go @@ -0,0 +1,139 @@ +package compiler_test + +import ( + "math/big" + "testing" +) + +func TestDefer(t *testing.T) { + t.Run("Simple", func(t *testing.T) { + src := `package main + var a int + func Main() int { + return h() + a + } + func h() int { + defer f() + return 1 + } + func f() { a += 2 }` + eval(t, src, big.NewInt(3)) + }) + t.Run("ValueUnchanged", func(t *testing.T) { + src := `package main + var a int + func Main() int { + defer f() + a = 3 + return a + } + func f() { a += 2 }` + eval(t, src, big.NewInt(3)) + }) + t.Run("Function", func(t *testing.T) { + src := `package main + var a int + func Main() int { + return h() + a + } + func h() int { + defer f() + a = 3 + return g() + } + func g() int { + a++ + return a + } + func f() { a += 2 }` + eval(t, src, big.NewInt(10)) + }) + t.Run("MultipleDefers", func(t *testing.T) { + src := `package main + var a int + func Main() int { + return h() + a + } + func h() int { + defer f() + defer g() + a = 3 + return a + } + func g() { a *= 2 } + func f() { a += 2 }` + eval(t, src, big.NewInt(11)) + }) + t.Run("FunctionLiteral", func(t *testing.T) { + src := `package main + var a int + func Main() int { + return h() + a + } + func h() int { + defer func() { + a = 10 + }() + a = 3 + return a + }` + eval(t, src, big.NewInt(13)) + }) +} + +func TestRecover(t *testing.T) { + t.Run("Panic", func(t *testing.T) { + src := `package foo + var a int + func Main() int { + return h() + a + } + func h() int { + defer func() { + if r := recover(); r != nil { + a = 3 + } else { + a = 4 + } + }() + a = 1 + panic("msg") + return a + }` + eval(t, src, big.NewInt(3)) + }) + t.Run("NoPanic", func(t *testing.T) { + src := `package foo + var a int + func Main() int { + return h() + a + } + func h() int { + defer func() { + if r := recover(); r != nil { + a = 3 + } else { + a = 4 + } + }() + a = 1 + return a + }` + eval(t, src, big.NewInt(5)) + }) + t.Run("PanicInDefer", func(t *testing.T) { + src := `package foo + var a int + func Main() int { + return h() + a + } + func h() int { + defer func() { a += 2; _ = recover() }() + defer func() { a *= 3; _ = recover(); panic("again") }() + a = 1 + panic("msg") + return a + }` + eval(t, src, big.NewInt(5)) + }) +} diff --git a/pkg/compiler/func_scope.go b/pkg/compiler/func_scope.go index aba84a5b9..2eca590fe 100644 --- a/pkg/compiler/func_scope.go +++ b/pkg/compiler/func_scope.go @@ -30,6 +30,12 @@ type funcScope struct { // Variables together with it's type in neo-vm. variables []string + // deferStack is a stack containing encountered `defer` statements. + deferStack []deferInfo + // finallyProcessed is a index of static slot with boolean flag determining + // if `defer` statement was already processed. + finallyProcessedIndex int + // Local variables vars varScope @@ -45,6 +51,12 @@ type funcScope struct { i int } +type deferInfo struct { + catchLabel uint16 + finallyLabel uint16 + expr *ast.CallExpr +} + func (c *codegen) newFuncScope(decl *ast.FuncDecl, label uint16) *funcScope { var name string if decl.Name != nil { @@ -115,6 +127,7 @@ func (c *funcScope) analyzeVoidCalls(node ast.Node) bool { func (c *funcScope) countLocals() int { size := 0 + hasDefer := false ast.Inspect(c.decl, func(n ast.Node) bool { switch n := n.(type) { case *ast.FuncType: @@ -124,8 +137,11 @@ func (c *funcScope) countLocals() int { } case *ast.AssignStmt: if n.Tok == token.DEFINE { - size += len(n.Rhs) + size += len(n.Lhs) } + case *ast.DeferStmt: + hasDefer = true + return false case *ast.ReturnStmt, *ast.IfStmt: size++ // This handles the inline GenDecl like "var x = 2" @@ -143,6 +159,10 @@ func (c *funcScope) countLocals() int { } return true }) + if hasDefer { + c.finallyProcessedIndex = size + size++ + } return size } diff --git a/pkg/compiler/global_test.go b/pkg/compiler/global_test.go index 2be4caf30..e4f29f89b 100644 --- a/pkg/compiler/global_test.go +++ b/pkg/compiler/global_test.go @@ -38,6 +38,18 @@ func TestMultiDeclaration(t *testing.T) { eval(t, src, big.NewInt(6)) } +func TestCountLocal(t *testing.T) { + src := `package foo + func Main() int { + a, b, c, d := f() + return a + b + c + d + } + func f() (int, int, int, int) { + return 1, 2, 3, 4 + }` + eval(t, src, big.NewInt(10)) +} + func TestMultiDeclarationLocal(t *testing.T) { src := `package foo func Main() int { diff --git a/pkg/compiler/if_test.go b/pkg/compiler/if_test.go index 23077ecf9..038c8fe92 100644 --- a/pkg/compiler/if_test.go +++ b/pkg/compiler/if_test.go @@ -1,6 +1,7 @@ package compiler_test import ( + "fmt" "math/big" "testing" ) @@ -91,3 +92,32 @@ func TestNestedIF(t *testing.T) { ` eval(t, src, big.NewInt(0)) } + +func TestInitIF(t *testing.T) { + t.Run("Simple", func(t *testing.T) { + src := `package foo + func Main() int { + if a := 42; true { + return a + } + return 0 + }` + eval(t, src, big.NewInt(42)) + }) + t.Run("Shadow", func(t *testing.T) { + srcTmpl := `package foo + func Main() int { + a := 11 + if a := 42; %v { + return a + } + return a + }` + t.Run("True", func(t *testing.T) { + eval(t, fmt.Sprintf(srcTmpl, true), big.NewInt(42)) + }) + t.Run("False", func(t *testing.T) { + eval(t, fmt.Sprintf(srcTmpl, false), big.NewInt(11)) + }) + }) +} diff --git a/pkg/compiler/lambda_test.go b/pkg/compiler/lambda_test.go index d9ab32613..f3e7132a7 100644 --- a/pkg/compiler/lambda_test.go +++ b/pkg/compiler/lambda_test.go @@ -13,3 +13,16 @@ func TestFuncLiteral(t *testing.T) { }` eval(t, src, big.NewInt(5)) } + +func TestCallInPlace(t *testing.T) { + src := `package foo + var a int = 1 + func Main() int { + func() { + a += 10 + }() + a += 100 + return a + }` + eval(t, src, big.NewInt(111)) +} diff --git a/pkg/compiler/panic_test.go b/pkg/compiler/panic_test.go index 54c51b9d8..3effe0956 100644 --- a/pkg/compiler/panic_test.go +++ b/pkg/compiler/panic_test.go @@ -1,13 +1,10 @@ package compiler_test import ( - "errors" "fmt" "math/big" "testing" - "github.com/nspcc-dev/neo-go/pkg/core/interop/interopnames" - "github.com/nspcc-dev/neo-go/pkg/vm" "github.com/stretchr/testify/require" ) @@ -18,26 +15,19 @@ func TestPanic(t *testing.T) { }) t.Run("panic with message", func(t *testing.T) { - var logs []string src := getPanicSource(true, `"execution fault"`) v := vmAndCompile(t, src) - v.SyscallHandler = getLogHandler(&logs) require.Error(t, v.Run()) require.True(t, v.HasFailed()) - require.Equal(t, 1, len(logs)) - require.Equal(t, "execution fault", logs[0]) }) t.Run("panic with nil", func(t *testing.T) { - var logs []string src := getPanicSource(true, `nil`) v := vmAndCompile(t, src) - v.SyscallHandler = getLogHandler(&logs) require.Error(t, v.Run()) require.True(t, v.HasFailed()) - require.Equal(t, 0, len(logs)) }) } @@ -54,16 +44,3 @@ func getPanicSource(need bool, message string) string { } `, need, message) } - -func getLogHandler(logs *[]string) vm.SyscallHandler { - logID := interopnames.ToID([]byte(interopnames.SystemRuntimeLog)) - return func(v *vm.VM, id uint32) error { - if id != logID { - return errors.New("syscall not found") - } - - msg := v.Estack().Pop().String() - *logs = append(*logs, msg) - return nil - } -} diff --git a/pkg/vm/emit/emit.go b/pkg/vm/emit/emit.go index 3e693a915..a5ce828df 100644 --- a/pkg/vm/emit/emit.go +++ b/pkg/vm/emit/emit.go @@ -160,5 +160,5 @@ func AppCallWithOperationAndArgs(w *io.BinWriter, scriptHash util.Uint160, opera } func isInstructionJmp(op opcode.Opcode) bool { - return opcode.JMP <= op && op <= opcode.CALLL + return opcode.JMP <= op && op <= opcode.CALLL || op == opcode.ENDTRYL } diff --git a/pkg/vm/vm.go b/pkg/vm/vm.go index ca68a2716..0a426fb6a 100644 --- a/pkg/vm/vm.go +++ b/pkg/vm/vm.go @@ -1476,6 +1476,7 @@ func (v *VM) Call(ctx *Context, offset int) { newCtx.CheckReturn = false newCtx.local = nil newCtx.arguments = nil + newCtx.tryStack = NewStack("exception") v.istack.PushVal(newCtx) v.Jump(newCtx, offset) } @@ -1517,7 +1518,7 @@ func (v *VM) handleException() { pop := 0 ictx := v.istack.Peek(0).Value().(*Context) for ictx != nil { - e := ictx.tryStack.Peek(pop) + e := ictx.tryStack.Peek(0) for e != nil { ectx := e.Value().(*exceptionHandlingContext) if ectx.State == eFinally || (ectx.State == eCatch && !ectx.HasFinally()) { diff --git a/pkg/vm/vm_test.go b/pkg/vm/vm_test.go index 65c51e2d4..198d2183e 100644 --- a/pkg/vm/vm_test.go +++ b/pkg/vm/vm_test.go @@ -1336,6 +1336,12 @@ func TestTRY(t *testing.T) { checkVMFailed(t, vm) }) }) + t.Run("ThrowInCall", func(t *testing.T) { + catchP := []byte{byte(opcode.CALL), 2, byte(opcode.PUSH1), byte(opcode.ADD), byte(opcode.THROW), byte(opcode.RET)} + inner := getTRYProgram(throw, catchP, []byte{byte(opcode.PUSH2)}) + // add 5 to the exception, mul to the result of inner finally (2) + getTRYTestFunc(47, inner, append(add5, byte(opcode.MUL)), add9)(t) + }) } func TestMEMCPY(t *testing.T) {