diff --git a/pkg/compiler/codegen.go b/pkg/compiler/codegen.go index 93ee4c711..95f3d03c9 100644 --- a/pkg/compiler/codegen.go +++ b/pkg/compiler/codegen.go @@ -553,6 +553,8 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor { } } + c.processDefers() + c.saveSequencePoint(n) emit.Opcode(c.prog.BinWriter, opcode.RET) return nil @@ -829,6 +831,17 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor { return nil + case *ast.DeferStmt: + finally := c.newLabel() + param := make([]byte, 8) + binary.LittleEndian.PutUint16(param[4:], finally) + emit.Instruction(c.prog.BinWriter, opcode.TRYL, param) + c.scope.deferStack = append(c.scope.deferStack, deferInfo{ + finallyLabel: finally, + expr: n.Call, + }) + return nil + case *ast.SelectorExpr: typ := c.typeOf(n.X) if typ == nil { @@ -1090,6 +1103,20 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor { return c } +// processDefers emits code for `defer` statements. +func (c *codegen) processDefers() { + for i := len(c.scope.deferStack) - 1; i >= 0; i-- { + stmt := c.scope.deferStack[i] + after := c.newLabel() + emit.Jmp(c.prog.BinWriter, opcode.ENDTRYL, after) + c.setLabel(stmt.finallyLabel) + // Execute body. + ast.Walk(c, stmt.expr) + emit.Opcode(c.prog.BinWriter, opcode.ENDFINALLY) + c.setLabel(after) + } +} + func (c *codegen) rangeLoadKey() { emit.Int(c.prog.BinWriter, 2) emit.Opcode(c.prog.BinWriter, opcode.PICK) // load keys @@ -1787,27 +1814,27 @@ func (c *codegen) writeJumps(b []byte) ([]byte, error) { case opcode.JMP, opcode.JMPIFNOT, opcode.JMPIF, opcode.CALL, opcode.JMPEQ, opcode.JMPNE, opcode.JMPGT, opcode.JMPGE, opcode.JMPLE, opcode.JMPLT: + case opcode.TRYL: + nextIP := ctx.NextIP() + finallyArg := b[nextIP-4:] + _, err := c.replaceLabelWithOffset(ctx.IP(), finallyArg) + if err != nil { + return nil, err + } case opcode.JMPL, opcode.JMPIFL, opcode.JMPIFNOTL, opcode.JMPEQL, opcode.JMPNEL, opcode.JMPGTL, opcode.JMPGEL, opcode.JMPLEL, opcode.JMPLTL, - opcode.CALLL, opcode.PUSHA: + opcode.CALLL, opcode.PUSHA, opcode.ENDTRYL: // we can't use arg returned by ctx.Next() because it is copied nextIP := ctx.NextIP() arg := b[nextIP-4:] - - index := binary.LittleEndian.Uint16(arg) - if int(index) > len(c.l) { - return nil, fmt.Errorf("unexpected label number: %d (max %d)", index, len(c.l)) - } - offset := c.l[index] - nextIP + 5 - if offset > math.MaxInt32 || offset < math.MinInt32 { - return nil, fmt.Errorf("label offset is too big at the instruction %d: %d (max %d, min %d)", - nextIP-5, offset, math.MaxInt32, math.MinInt32) + offset, err := c.replaceLabelWithOffset(ctx.IP(), arg) + if err != nil { + return nil, err } if op != opcode.PUSHA && math.MinInt8 <= offset && offset <= math.MaxInt8 { offsets = append(offsets, ctx.IP()) } - binary.LittleEndian.PutUint32(arg, uint32(offset)) } } // Correct function ip range. @@ -1829,6 +1856,20 @@ func (c *codegen) writeJumps(b []byte) ([]byte, error) { return shortenJumps(b, offsets), nil } +func (c *codegen) replaceLabelWithOffset(ip int, arg []byte) (int, error) { + index := binary.LittleEndian.Uint16(arg) + if int(index) > len(c.l) { + return 0, fmt.Errorf("unexpected label number: %d (max %d)", index, len(c.l)) + } + offset := c.l[index] - ip + if offset > math.MaxInt32 || offset < math.MinInt32 { + return 0, fmt.Errorf("label offset is too big at the instruction %d: %d (max %d, min %d)", + ip, offset, math.MaxInt32, math.MinInt32) + } + binary.LittleEndian.PutUint32(arg, uint32(offset)) + return offset, nil +} + // longToShortRemoveCount is a difference between short and long instruction sizes in bytes. const longToShortRemoveCount = 3 @@ -1853,18 +1894,27 @@ func shortenJumps(b []byte, offsets []int) []byte { switch op { case opcode.JMP, opcode.JMPIFNOT, opcode.JMPIF, opcode.CALL, opcode.JMPEQ, opcode.JMPNE, - opcode.JMPGT, opcode.JMPGE, opcode.JMPLE, opcode.JMPLT: + opcode.JMPGT, opcode.JMPGE, opcode.JMPLE, opcode.JMPLT, opcode.ENDTRY: offset := int(int8(b[nextIP-1])) offset += calcOffsetCorrection(ip, ip+offset, offsets) b[nextIP-1] = byte(offset) + case opcode.TRY: + finallyOffset := int(int8(b[nextIP-1])) + finallyOffset += calcOffsetCorrection(ip, ip+finallyOffset, offsets) + b[nextIP-1] = byte(finallyOffset) case opcode.JMPL, opcode.JMPIFL, opcode.JMPIFNOTL, opcode.JMPEQL, opcode.JMPNEL, opcode.JMPGTL, opcode.JMPGEL, opcode.JMPLEL, opcode.JMPLTL, - opcode.CALLL, opcode.PUSHA: + opcode.CALLL, opcode.PUSHA, opcode.ENDTRYL: arg := b[nextIP-4:] offset := int(int32(binary.LittleEndian.Uint32(arg))) offset += calcOffsetCorrection(ip, ip+offset, offsets) binary.LittleEndian.PutUint32(arg, uint32(offset)) + case opcode.TRYL: + arg := b[nextIP-4:] + finallyOffset := int(int32(binary.LittleEndian.Uint32(arg))) + finallyOffset += calcOffsetCorrection(ip, ip+finallyOffset, offsets) + binary.LittleEndian.PutUint32(arg, uint32(finallyOffset)) } } @@ -1948,6 +1998,8 @@ func toShortForm(op opcode.Opcode) opcode.Opcode { return opcode.JMPLT case opcode.CALLL: return opcode.CALL + case opcode.ENDTRYL: + return opcode.ENDTRY default: panic(fmt.Errorf("invalid opcode: %s", op)) } diff --git a/pkg/compiler/defer_test.go b/pkg/compiler/defer_test.go new file mode 100644 index 000000000..8bc8ec440 --- /dev/null +++ b/pkg/compiler/defer_test.go @@ -0,0 +1,82 @@ +package compiler_test + +import ( + "math/big" + "testing" +) + +func TestDefer(t *testing.T) { + t.Run("Simple", func(t *testing.T) { + src := `package main + var a int + func Main() int { + return h() + a + } + func h() int { + defer f() + return 1 + } + func f() { a += 2 }` + eval(t, src, big.NewInt(3)) + }) + t.Run("ValueUnchanged", func(t *testing.T) { + src := `package main + var a int + func Main() int { + defer f() + a = 3 + return a + } + func f() { a += 2 }` + eval(t, src, big.NewInt(3)) + }) + t.Run("Function", func(t *testing.T) { + src := `package main + var a int + func Main() int { + return h() + a + } + func h() int { + defer f() + a = 3 + return g() + } + func g() int { + a++ + return a + } + func f() { a += 2 }` + eval(t, src, big.NewInt(10)) + }) + t.Run("MultipleDefers", func(t *testing.T) { + src := `package main + var a int + func Main() int { + return h() + a + } + func h() int { + defer f() + defer g() + a = 3 + return a + } + func g() { a *= 2 } + func f() { a += 2 }` + eval(t, src, big.NewInt(11)) + }) + t.Run("FunctionLiteral", func(t *testing.T) { + src := `package main + var a int + func Main() int { + return h() + a + } + func h() int { + defer func() { + a = 10 + }() + a = 3 + return a + }` + eval(t, src, big.NewInt(13)) + }) +} diff --git a/pkg/compiler/func_scope.go b/pkg/compiler/func_scope.go index 7a9e09ed6..f258e9f2d 100644 --- a/pkg/compiler/func_scope.go +++ b/pkg/compiler/func_scope.go @@ -30,6 +30,9 @@ type funcScope struct { // Variables together with it's type in neo-vm. variables []string + // deferStack is a stack containing encountered `defer` statements. + deferStack []deferInfo + // Local variables vars varScope @@ -45,6 +48,11 @@ type funcScope struct { i int } +type deferInfo struct { + finallyLabel uint16 + expr *ast.CallExpr +} + func (c *codegen) newFuncScope(decl *ast.FuncDecl, label uint16) *funcScope { var name string if decl.Name != nil { diff --git a/pkg/vm/emit/emit.go b/pkg/vm/emit/emit.go index 3e693a915..a5ce828df 100644 --- a/pkg/vm/emit/emit.go +++ b/pkg/vm/emit/emit.go @@ -160,5 +160,5 @@ func AppCallWithOperationAndArgs(w *io.BinWriter, scriptHash util.Uint160, opera } func isInstructionJmp(op opcode.Opcode) bool { - return opcode.JMP <= op && op <= opcode.CALLL + return opcode.JMP <= op && op <= opcode.CALLL || op == opcode.ENDTRYL }