From 5d82b82efb666c4c94dd4289b6a364a3b6fe11e8 Mon Sep 17 00:00:00 2001 From: Evgenii Stratonikov Date: Fri, 21 Aug 2020 15:37:46 +0300 Subject: [PATCH] compiler: support `recover()` --- pkg/compiler/analysis.go | 22 +++++++++--- pkg/compiler/codegen.go | 71 +++++++++++++++++++++++++++++++++++--- pkg/compiler/defer_test.go | 57 ++++++++++++++++++++++++++++++ pkg/compiler/func_scope.go | 12 +++++++ 4 files changed, 153 insertions(+), 9 deletions(-) diff --git a/pkg/compiler/analysis.go b/pkg/compiler/analysis.go index 62ca08ce6..730a9bd89 100644 --- a/pkg/compiler/analysis.go +++ b/pkg/compiler/analysis.go @@ -14,7 +14,7 @@ import ( var ( // Go language builtin functions. - goBuiltins = []string{"len", "append", "panic", "make", "copy"} + goBuiltins = []string{"len", "append", "panic", "make", "copy", "recover"} // Custom builtin utility functions. customBuiltins = []string{ "FromAddress", "Equals", @@ -40,23 +40,30 @@ func (c *codegen) getIdentName(pkg string, name string) string { // and returns number of variables initialized and // true if any init functions were encountered. func (c *codegen) traverseGlobals() (int, bool) { + var hasDefer bool var n int var hasInit bool c.ForEachFile(func(f *ast.File, _ *types.Package) { n += countGlobals(f) - if !hasInit { + if !hasInit || !hasDefer { ast.Inspect(f, func(node ast.Node) bool { - n, ok := node.(*ast.FuncDecl) - if ok { + switch n := node.(type) { + case *ast.FuncDecl: if isInitFunc(n) { hasInit = true } + return !hasDefer + case *ast.DeferStmt: + hasDefer = true return false } return true }) } }) + if hasDefer { + n++ + } if n != 0 || hasInit { if n > 255 { c.prog.BinWriter.Err = errors.New("too many global variables") @@ -83,6 +90,11 @@ func (c *codegen) traverseGlobals() (int, bool) { // encountered after will be recognized as globals. c.scope = nil }) + // store auxiliary variables after all others. + if hasDefer { + c.exceptionIndex = len(c.globals) + c.globals[""] = c.exceptionIndex + } } return n, hasInit } @@ -92,7 +104,7 @@ func (c *codegen) traverseGlobals() (int, bool) { func countGlobals(f ast.Node) (i int) { ast.Inspect(f, func(node ast.Node) bool { switch n := node.(type) { - // Skip all function declarations. + // Skip all function declarations if we have already encountered `defer`. case *ast.FuncDecl: return false // After skipping all funcDecls we are sure that each value spec diff --git a/pkg/compiler/codegen.go b/pkg/compiler/codegen.go index c040b771d..a5fe85d1c 100644 --- a/pkg/compiler/codegen.go +++ b/pkg/compiler/codegen.go @@ -77,6 +77,9 @@ type codegen struct { // packages contains packages in the order they were loaded. packages []string + // exceptionIndex is the index of static slot where exception is stored. + exceptionIndex int + // documents contains paths to all files used by the program. documents []string // docIndex maps file path to an index in documents array. @@ -216,6 +219,11 @@ func getBaseOpcode(t varType) (opcode.Opcode, opcode.Opcode) { // emitLoadVar loads specified variable to the evaluation stack. func (c *codegen) emitLoadVar(pkg string, name string) { t, i := c.getVarIndex(pkg, name) + c.emitLoadByIndex(t, i) +} + +// emitLoadByIndex loads specified variable type with index i. +func (c *codegen) emitLoadByIndex(t varType, i int) { base, _ := getBaseOpcode(t) if i < 7 { emit.Opcode(c.prog.BinWriter, base+opcode.Opcode(i)) @@ -231,6 +239,11 @@ func (c *codegen) emitStoreVar(pkg string, name string) { return } t, i := c.getVarIndex(pkg, name) + c.emitStoreByIndex(t, i) +} + +// emitLoadByIndex stores top value in the specified variable type with index i. +func (c *codegen) emitStoreByIndex(t varType, i int) { _, base := getBaseOpcode(t) if i < 7 { emit.Opcode(c.prog.BinWriter, base+opcode.Opcode(i)) @@ -831,11 +844,14 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor { return nil case *ast.DeferStmt: + catch := c.newLabel() finally := c.newLabel() param := make([]byte, 8) + binary.LittleEndian.PutUint16(param[0:], catch) binary.LittleEndian.PutUint16(param[4:], finally) emit.Instruction(c.prog.BinWriter, opcode.TRYL, param) c.scope.deferStack = append(c.scope.deferStack, deferInfo{ + catchLabel: catch, finallyLabel: finally, expr: n.Call, }) @@ -1103,14 +1119,45 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor { } // processDefers emits code for `defer` statements. +// TRY-related opcodes handle exception as follows: +// 1. CATCH block is executed only if exception has occured. +// 2. FINALLY block is always executed, but after catch block. +// Go `defer` statements are a bit different: +// 1. `defer` is always executed irregardless of whether an exception has occured. +// 2. `recover` can or can not handle a possible exception. +// Thus we use the following approach: +// 1. Throwed exception is saved in a static field X, static fields Y and is set to true. +// 2. CATCH and FINALLY blocks are the same, and both contain the same CALLs. +// 3. In CATCH block we set Y to true and emit default return values if it is the last defer. +// 4. Execute FINALLY block only if Y is false. func (c *codegen) processDefers() { for i := len(c.scope.deferStack) - 1; i >= 0; i-- { stmt := c.scope.deferStack[i] after := c.newLabel() emit.Jmp(c.prog.BinWriter, opcode.ENDTRYL, after) - c.setLabel(stmt.finallyLabel) - // Execute body. + + c.setLabel(stmt.catchLabel) + c.emitStoreByIndex(varGlobal, c.exceptionIndex) + emit.Int(c.prog.BinWriter, 1) + c.emitStoreByIndex(varLocal, c.scope.finallyProcessedIndex) ast.Walk(c, stmt.expr) + if i == 0 { + // After panic, default values must be returns, except for named returns, + // which we don't support here for now. + for i := len(c.scope.decl.Type.Results.List) - 1; i >= 0; i-- { + c.emitDefault(c.typeOf(c.scope.decl.Type.Results.List[i].Type)) + } + } + emit.Jmp(c.prog.BinWriter, opcode.ENDTRYL, after) + + c.setLabel(stmt.finallyLabel) + before := c.newLabel() + c.emitLoadByIndex(varLocal, c.scope.finallyProcessedIndex) + emit.Jmp(c.prog.BinWriter, opcode.JMPIFL, before) + ast.Walk(c, stmt.expr) + c.setLabel(before) + emit.Int(c.prog.BinWriter, 0) + c.emitStoreByIndex(varLocal, c.scope.finallyProcessedIndex) emit.Opcode(c.prog.BinWriter, opcode.ENDFINALLY) c.setLabel(after) } @@ -1464,6 +1511,10 @@ func (c *codegen) convertBuiltin(expr *ast.CallExpr) { } case "panic": emit.Opcode(c.prog.BinWriter, opcode.THROW) + case "recover": + c.emitLoadByIndex(varGlobal, c.exceptionIndex) + emit.Opcode(c.prog.BinWriter, opcode.PUSHNULL) + c.emitStoreByIndex(varGlobal, c.exceptionIndex) case "ToInteger", "ToByteArray", "ToBool": typ := stackitem.IntegerT switch name { @@ -1803,8 +1854,13 @@ func (c *codegen) writeJumps(b []byte) ([]byte, error) { opcode.JMPGT, opcode.JMPGE, opcode.JMPLE, opcode.JMPLT: case opcode.TRYL: nextIP := ctx.NextIP() + catchArg := b[nextIP-8:] + _, err := c.replaceLabelWithOffset(ctx.IP(), catchArg) + if err != nil { + return nil, err + } finallyArg := b[nextIP-4:] - _, err := c.replaceLabelWithOffset(ctx.IP(), finallyArg) + _, err = c.replaceLabelWithOffset(ctx.IP(), finallyArg) if err != nil { return nil, err } @@ -1886,6 +1942,9 @@ func shortenJumps(b []byte, offsets []int) []byte { offset += calcOffsetCorrection(ip, ip+offset, offsets) b[nextIP-1] = byte(offset) case opcode.TRY: + catchOffset := int(int8(b[nextIP-2])) + catchOffset += calcOffsetCorrection(ip, ip+catchOffset, offsets) + b[nextIP-1] = byte(catchOffset) finallyOffset := int(int8(b[nextIP-1])) finallyOffset += calcOffsetCorrection(ip, ip+finallyOffset, offsets) b[nextIP-1] = byte(finallyOffset) @@ -1898,7 +1957,11 @@ func shortenJumps(b []byte, offsets []int) []byte { offset += calcOffsetCorrection(ip, ip+offset, offsets) binary.LittleEndian.PutUint32(arg, uint32(offset)) case opcode.TRYL: - arg := b[nextIP-4:] + arg := b[nextIP-8:] + catchOffset := int(int32(binary.LittleEndian.Uint32(arg))) + catchOffset += calcOffsetCorrection(ip, ip+catchOffset, offsets) + binary.LittleEndian.PutUint32(arg, uint32(catchOffset)) + arg = b[nextIP-4:] finallyOffset := int(int32(binary.LittleEndian.Uint32(arg))) finallyOffset += calcOffsetCorrection(ip, ip+finallyOffset, offsets) binary.LittleEndian.PutUint32(arg, uint32(finallyOffset)) diff --git a/pkg/compiler/defer_test.go b/pkg/compiler/defer_test.go index 8bc8ec440..4a2f8f3b0 100644 --- a/pkg/compiler/defer_test.go +++ b/pkg/compiler/defer_test.go @@ -80,3 +80,60 @@ func TestDefer(t *testing.T) { eval(t, src, big.NewInt(13)) }) } + +func TestRecover(t *testing.T) { + t.Run("Panic", func(t *testing.T) { + src := `package foo + var a int + func Main() int { + return h() + a + } + func h() int { + defer func() { + if r := recover(); r != nil { + a = 3 + } else { + a = 4 + } + }() + a = 1 + panic("msg") + return a + }` + eval(t, src, big.NewInt(3)) + }) + t.Run("NoPanic", func(t *testing.T) { + src := `package foo + var a int + func Main() int { + return h() + a + } + func h() int { + defer func() { + if r := recover(); r != nil { + a = 3 + } else { + a = 4 + } + }() + a = 1 + return a + }` + eval(t, src, big.NewInt(5)) + }) + t.Run("PanicInDefer", func(t *testing.T) { + src := `package foo + var a int + func Main() int { + return h() + a + } + func h() int { + defer func() { a += 2; _ = recover() }() + defer func() { a *= 3; _ = recover(); panic("again") }() + a = 1 + panic("msg") + return a + }` + eval(t, src, big.NewInt(5)) + }) +} diff --git a/pkg/compiler/func_scope.go b/pkg/compiler/func_scope.go index f258e9f2d..2eca590fe 100644 --- a/pkg/compiler/func_scope.go +++ b/pkg/compiler/func_scope.go @@ -32,6 +32,9 @@ type funcScope struct { // deferStack is a stack containing encountered `defer` statements. deferStack []deferInfo + // finallyProcessed is a index of static slot with boolean flag determining + // if `defer` statement was already processed. + finallyProcessedIndex int // Local variables vars varScope @@ -49,6 +52,7 @@ type funcScope struct { } type deferInfo struct { + catchLabel uint16 finallyLabel uint16 expr *ast.CallExpr } @@ -123,6 +127,7 @@ func (c *funcScope) analyzeVoidCalls(node ast.Node) bool { func (c *funcScope) countLocals() int { size := 0 + hasDefer := false ast.Inspect(c.decl, func(n ast.Node) bool { switch n := n.(type) { case *ast.FuncType: @@ -134,6 +139,9 @@ func (c *funcScope) countLocals() int { if n.Tok == token.DEFINE { size += len(n.Lhs) } + case *ast.DeferStmt: + hasDefer = true + return false case *ast.ReturnStmt, *ast.IfStmt: size++ // This handles the inline GenDecl like "var x = 2" @@ -151,6 +159,10 @@ func (c *funcScope) countLocals() int { } return true }) + if hasDefer { + c.finallyProcessedIndex = size + size++ + } return size }