compiler: support recover()

This commit is contained in:
Evgenii Stratonikov 2020-08-21 15:37:46 +03:00
parent 14ea3c2228
commit 5d82b82efb
4 changed files with 153 additions and 9 deletions

View file

@ -14,7 +14,7 @@ import (
var ( var (
// Go language builtin functions. // Go language builtin functions.
goBuiltins = []string{"len", "append", "panic", "make", "copy"} goBuiltins = []string{"len", "append", "panic", "make", "copy", "recover"}
// Custom builtin utility functions. // Custom builtin utility functions.
customBuiltins = []string{ customBuiltins = []string{
"FromAddress", "Equals", "FromAddress", "Equals",
@ -40,23 +40,30 @@ func (c *codegen) getIdentName(pkg string, name string) string {
// and returns number of variables initialized and // and returns number of variables initialized and
// true if any init functions were encountered. // true if any init functions were encountered.
func (c *codegen) traverseGlobals() (int, bool) { func (c *codegen) traverseGlobals() (int, bool) {
var hasDefer bool
var n int var n int
var hasInit bool var hasInit bool
c.ForEachFile(func(f *ast.File, _ *types.Package) { c.ForEachFile(func(f *ast.File, _ *types.Package) {
n += countGlobals(f) n += countGlobals(f)
if !hasInit { if !hasInit || !hasDefer {
ast.Inspect(f, func(node ast.Node) bool { ast.Inspect(f, func(node ast.Node) bool {
n, ok := node.(*ast.FuncDecl) switch n := node.(type) {
if ok { case *ast.FuncDecl:
if isInitFunc(n) { if isInitFunc(n) {
hasInit = true hasInit = true
} }
return !hasDefer
case *ast.DeferStmt:
hasDefer = true
return false return false
} }
return true return true
}) })
} }
}) })
if hasDefer {
n++
}
if n != 0 || hasInit { if n != 0 || hasInit {
if n > 255 { if n > 255 {
c.prog.BinWriter.Err = errors.New("too many global variables") c.prog.BinWriter.Err = errors.New("too many global variables")
@ -83,6 +90,11 @@ func (c *codegen) traverseGlobals() (int, bool) {
// encountered after will be recognized as globals. // encountered after will be recognized as globals.
c.scope = nil c.scope = nil
}) })
// store auxiliary variables after all others.
if hasDefer {
c.exceptionIndex = len(c.globals)
c.globals["<exception>"] = c.exceptionIndex
}
} }
return n, hasInit return n, hasInit
} }
@ -92,7 +104,7 @@ func (c *codegen) traverseGlobals() (int, bool) {
func countGlobals(f ast.Node) (i int) { func countGlobals(f ast.Node) (i int) {
ast.Inspect(f, func(node ast.Node) bool { ast.Inspect(f, func(node ast.Node) bool {
switch n := node.(type) { switch n := node.(type) {
// Skip all function declarations. // Skip all function declarations if we have already encountered `defer`.
case *ast.FuncDecl: case *ast.FuncDecl:
return false return false
// After skipping all funcDecls we are sure that each value spec // After skipping all funcDecls we are sure that each value spec

View file

@ -77,6 +77,9 @@ type codegen struct {
// packages contains packages in the order they were loaded. // packages contains packages in the order they were loaded.
packages []string packages []string
// exceptionIndex is the index of static slot where exception is stored.
exceptionIndex int
// documents contains paths to all files used by the program. // documents contains paths to all files used by the program.
documents []string documents []string
// docIndex maps file path to an index in documents array. // docIndex maps file path to an index in documents array.
@ -216,6 +219,11 @@ func getBaseOpcode(t varType) (opcode.Opcode, opcode.Opcode) {
// emitLoadVar loads specified variable to the evaluation stack. // emitLoadVar loads specified variable to the evaluation stack.
func (c *codegen) emitLoadVar(pkg string, name string) { func (c *codegen) emitLoadVar(pkg string, name string) {
t, i := c.getVarIndex(pkg, name) t, i := c.getVarIndex(pkg, name)
c.emitLoadByIndex(t, i)
}
// emitLoadByIndex loads specified variable type with index i.
func (c *codegen) emitLoadByIndex(t varType, i int) {
base, _ := getBaseOpcode(t) base, _ := getBaseOpcode(t)
if i < 7 { if i < 7 {
emit.Opcode(c.prog.BinWriter, base+opcode.Opcode(i)) emit.Opcode(c.prog.BinWriter, base+opcode.Opcode(i))
@ -231,6 +239,11 @@ func (c *codegen) emitStoreVar(pkg string, name string) {
return return
} }
t, i := c.getVarIndex(pkg, name) t, i := c.getVarIndex(pkg, name)
c.emitStoreByIndex(t, i)
}
// emitLoadByIndex stores top value in the specified variable type with index i.
func (c *codegen) emitStoreByIndex(t varType, i int) {
_, base := getBaseOpcode(t) _, base := getBaseOpcode(t)
if i < 7 { if i < 7 {
emit.Opcode(c.prog.BinWriter, base+opcode.Opcode(i)) emit.Opcode(c.prog.BinWriter, base+opcode.Opcode(i))
@ -831,11 +844,14 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
return nil return nil
case *ast.DeferStmt: case *ast.DeferStmt:
catch := c.newLabel()
finally := c.newLabel() finally := c.newLabel()
param := make([]byte, 8) param := make([]byte, 8)
binary.LittleEndian.PutUint16(param[0:], catch)
binary.LittleEndian.PutUint16(param[4:], finally) binary.LittleEndian.PutUint16(param[4:], finally)
emit.Instruction(c.prog.BinWriter, opcode.TRYL, param) emit.Instruction(c.prog.BinWriter, opcode.TRYL, param)
c.scope.deferStack = append(c.scope.deferStack, deferInfo{ c.scope.deferStack = append(c.scope.deferStack, deferInfo{
catchLabel: catch,
finallyLabel: finally, finallyLabel: finally,
expr: n.Call, expr: n.Call,
}) })
@ -1103,14 +1119,45 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
} }
// processDefers emits code for `defer` statements. // processDefers emits code for `defer` statements.
// TRY-related opcodes handle exception as follows:
// 1. CATCH block is executed only if exception has occured.
// 2. FINALLY block is always executed, but after catch block.
// Go `defer` statements are a bit different:
// 1. `defer` is always executed irregardless of whether an exception has occured.
// 2. `recover` can or can not handle a possible exception.
// Thus we use the following approach:
// 1. Throwed exception is saved in a static field X, static fields Y and is set to true.
// 2. CATCH and FINALLY blocks are the same, and both contain the same CALLs.
// 3. In CATCH block we set Y to true and emit default return values if it is the last defer.
// 4. Execute FINALLY block only if Y is false.
func (c *codegen) processDefers() { func (c *codegen) processDefers() {
for i := len(c.scope.deferStack) - 1; i >= 0; i-- { for i := len(c.scope.deferStack) - 1; i >= 0; i-- {
stmt := c.scope.deferStack[i] stmt := c.scope.deferStack[i]
after := c.newLabel() after := c.newLabel()
emit.Jmp(c.prog.BinWriter, opcode.ENDTRYL, after) emit.Jmp(c.prog.BinWriter, opcode.ENDTRYL, after)
c.setLabel(stmt.finallyLabel)
// Execute body. c.setLabel(stmt.catchLabel)
c.emitStoreByIndex(varGlobal, c.exceptionIndex)
emit.Int(c.prog.BinWriter, 1)
c.emitStoreByIndex(varLocal, c.scope.finallyProcessedIndex)
ast.Walk(c, stmt.expr) ast.Walk(c, stmt.expr)
if i == 0 {
// After panic, default values must be returns, except for named returns,
// which we don't support here for now.
for i := len(c.scope.decl.Type.Results.List) - 1; i >= 0; i-- {
c.emitDefault(c.typeOf(c.scope.decl.Type.Results.List[i].Type))
}
}
emit.Jmp(c.prog.BinWriter, opcode.ENDTRYL, after)
c.setLabel(stmt.finallyLabel)
before := c.newLabel()
c.emitLoadByIndex(varLocal, c.scope.finallyProcessedIndex)
emit.Jmp(c.prog.BinWriter, opcode.JMPIFL, before)
ast.Walk(c, stmt.expr)
c.setLabel(before)
emit.Int(c.prog.BinWriter, 0)
c.emitStoreByIndex(varLocal, c.scope.finallyProcessedIndex)
emit.Opcode(c.prog.BinWriter, opcode.ENDFINALLY) emit.Opcode(c.prog.BinWriter, opcode.ENDFINALLY)
c.setLabel(after) c.setLabel(after)
} }
@ -1464,6 +1511,10 @@ func (c *codegen) convertBuiltin(expr *ast.CallExpr) {
} }
case "panic": case "panic":
emit.Opcode(c.prog.BinWriter, opcode.THROW) emit.Opcode(c.prog.BinWriter, opcode.THROW)
case "recover":
c.emitLoadByIndex(varGlobal, c.exceptionIndex)
emit.Opcode(c.prog.BinWriter, opcode.PUSHNULL)
c.emitStoreByIndex(varGlobal, c.exceptionIndex)
case "ToInteger", "ToByteArray", "ToBool": case "ToInteger", "ToByteArray", "ToBool":
typ := stackitem.IntegerT typ := stackitem.IntegerT
switch name { switch name {
@ -1803,8 +1854,13 @@ func (c *codegen) writeJumps(b []byte) ([]byte, error) {
opcode.JMPGT, opcode.JMPGE, opcode.JMPLE, opcode.JMPLT: opcode.JMPGT, opcode.JMPGE, opcode.JMPLE, opcode.JMPLT:
case opcode.TRYL: case opcode.TRYL:
nextIP := ctx.NextIP() nextIP := ctx.NextIP()
catchArg := b[nextIP-8:]
_, err := c.replaceLabelWithOffset(ctx.IP(), catchArg)
if err != nil {
return nil, err
}
finallyArg := b[nextIP-4:] finallyArg := b[nextIP-4:]
_, err := c.replaceLabelWithOffset(ctx.IP(), finallyArg) _, err = c.replaceLabelWithOffset(ctx.IP(), finallyArg)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -1886,6 +1942,9 @@ func shortenJumps(b []byte, offsets []int) []byte {
offset += calcOffsetCorrection(ip, ip+offset, offsets) offset += calcOffsetCorrection(ip, ip+offset, offsets)
b[nextIP-1] = byte(offset) b[nextIP-1] = byte(offset)
case opcode.TRY: case opcode.TRY:
catchOffset := int(int8(b[nextIP-2]))
catchOffset += calcOffsetCorrection(ip, ip+catchOffset, offsets)
b[nextIP-1] = byte(catchOffset)
finallyOffset := int(int8(b[nextIP-1])) finallyOffset := int(int8(b[nextIP-1]))
finallyOffset += calcOffsetCorrection(ip, ip+finallyOffset, offsets) finallyOffset += calcOffsetCorrection(ip, ip+finallyOffset, offsets)
b[nextIP-1] = byte(finallyOffset) b[nextIP-1] = byte(finallyOffset)
@ -1898,7 +1957,11 @@ func shortenJumps(b []byte, offsets []int) []byte {
offset += calcOffsetCorrection(ip, ip+offset, offsets) offset += calcOffsetCorrection(ip, ip+offset, offsets)
binary.LittleEndian.PutUint32(arg, uint32(offset)) binary.LittleEndian.PutUint32(arg, uint32(offset))
case opcode.TRYL: case opcode.TRYL:
arg := b[nextIP-4:] arg := b[nextIP-8:]
catchOffset := int(int32(binary.LittleEndian.Uint32(arg)))
catchOffset += calcOffsetCorrection(ip, ip+catchOffset, offsets)
binary.LittleEndian.PutUint32(arg, uint32(catchOffset))
arg = b[nextIP-4:]
finallyOffset := int(int32(binary.LittleEndian.Uint32(arg))) finallyOffset := int(int32(binary.LittleEndian.Uint32(arg)))
finallyOffset += calcOffsetCorrection(ip, ip+finallyOffset, offsets) finallyOffset += calcOffsetCorrection(ip, ip+finallyOffset, offsets)
binary.LittleEndian.PutUint32(arg, uint32(finallyOffset)) binary.LittleEndian.PutUint32(arg, uint32(finallyOffset))

View file

@ -80,3 +80,60 @@ func TestDefer(t *testing.T) {
eval(t, src, big.NewInt(13)) eval(t, src, big.NewInt(13))
}) })
} }
func TestRecover(t *testing.T) {
t.Run("Panic", func(t *testing.T) {
src := `package foo
var a int
func Main() int {
return h() + a
}
func h() int {
defer func() {
if r := recover(); r != nil {
a = 3
} else {
a = 4
}
}()
a = 1
panic("msg")
return a
}`
eval(t, src, big.NewInt(3))
})
t.Run("NoPanic", func(t *testing.T) {
src := `package foo
var a int
func Main() int {
return h() + a
}
func h() int {
defer func() {
if r := recover(); r != nil {
a = 3
} else {
a = 4
}
}()
a = 1
return a
}`
eval(t, src, big.NewInt(5))
})
t.Run("PanicInDefer", func(t *testing.T) {
src := `package foo
var a int
func Main() int {
return h() + a
}
func h() int {
defer func() { a += 2; _ = recover() }()
defer func() { a *= 3; _ = recover(); panic("again") }()
a = 1
panic("msg")
return a
}`
eval(t, src, big.NewInt(5))
})
}

View file

@ -32,6 +32,9 @@ type funcScope struct {
// deferStack is a stack containing encountered `defer` statements. // deferStack is a stack containing encountered `defer` statements.
deferStack []deferInfo deferStack []deferInfo
// finallyProcessed is a index of static slot with boolean flag determining
// if `defer` statement was already processed.
finallyProcessedIndex int
// Local variables // Local variables
vars varScope vars varScope
@ -49,6 +52,7 @@ type funcScope struct {
} }
type deferInfo struct { type deferInfo struct {
catchLabel uint16
finallyLabel uint16 finallyLabel uint16
expr *ast.CallExpr expr *ast.CallExpr
} }
@ -123,6 +127,7 @@ func (c *funcScope) analyzeVoidCalls(node ast.Node) bool {
func (c *funcScope) countLocals() int { func (c *funcScope) countLocals() int {
size := 0 size := 0
hasDefer := false
ast.Inspect(c.decl, func(n ast.Node) bool { ast.Inspect(c.decl, func(n ast.Node) bool {
switch n := n.(type) { switch n := n.(type) {
case *ast.FuncType: case *ast.FuncType:
@ -134,6 +139,9 @@ func (c *funcScope) countLocals() int {
if n.Tok == token.DEFINE { if n.Tok == token.DEFINE {
size += len(n.Lhs) size += len(n.Lhs)
} }
case *ast.DeferStmt:
hasDefer = true
return false
case *ast.ReturnStmt, *ast.IfStmt: case *ast.ReturnStmt, *ast.IfStmt:
size++ size++
// This handles the inline GenDecl like "var x = 2" // This handles the inline GenDecl like "var x = 2"
@ -151,6 +159,10 @@ func (c *funcScope) countLocals() int {
} }
return true return true
}) })
if hasDefer {
c.finallyProcessedIndex = size
size++
}
return size return size
} }