Merge pull request #1343 from nspcc-dev/compiler/recover

Support `defer` statement
This commit is contained in:
Roman Khimov 2020-08-27 17:33:12 +03:00 committed by GitHub
commit 3468f97836
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 384 additions and 63 deletions

View file

@ -19,11 +19,11 @@ a dialect of Go rather than a complete port of the language:
* goroutines, channels and garbage collection are not supported and will * goroutines, channels and garbage collection are not supported and will
never be because emulating that aspects of Go runtime on top of Neo VM is never be because emulating that aspects of Go runtime on top of Neo VM is
close to impossible close to impossible
* even though `panic()` is supported, `recover()` is not, `panic` shuts the * `defer` and `recover` are supported except for cases where panic occurs in
VM down `return` statement, because this complicates implementation and imposes runtime
* lambdas are not supported (#939) overhead for all contracts. This can easily be mitigated by first storing values
* it's not possible to rename imported interop packages, they won't work this in variables and returning the result.
way (#397, #913) * lambdas are supported, but closures are not.
## VM API (interop layer) ## VM API (interop layer)
Compiler translates interop function calls into NEO VM syscalls or (for custom Compiler translates interop function calls into NEO VM syscalls or (for custom

View file

@ -14,7 +14,7 @@ import (
var ( var (
// Go language builtin functions. // Go language builtin functions.
goBuiltins = []string{"len", "append", "panic", "make", "copy"} goBuiltins = []string{"len", "append", "panic", "make", "copy", "recover"}
// Custom builtin utility functions. // Custom builtin utility functions.
customBuiltins = []string{ customBuiltins = []string{
"FromAddress", "Equals", "FromAddress", "Equals",
@ -40,23 +40,30 @@ func (c *codegen) getIdentName(pkg string, name string) string {
// and returns number of variables initialized and // and returns number of variables initialized and
// true if any init functions were encountered. // true if any init functions were encountered.
func (c *codegen) traverseGlobals() (int, bool) { func (c *codegen) traverseGlobals() (int, bool) {
var hasDefer bool
var n int var n int
var hasInit bool var hasInit bool
c.ForEachFile(func(f *ast.File, _ *types.Package) { c.ForEachFile(func(f *ast.File, _ *types.Package) {
n += countGlobals(f) n += countGlobals(f)
if !hasInit { if !hasInit || !hasDefer {
ast.Inspect(f, func(node ast.Node) bool { ast.Inspect(f, func(node ast.Node) bool {
n, ok := node.(*ast.FuncDecl) switch n := node.(type) {
if ok { case *ast.FuncDecl:
if isInitFunc(n) { if isInitFunc(n) {
hasInit = true hasInit = true
} }
return !hasDefer
case *ast.DeferStmt:
hasDefer = true
return false return false
} }
return true return true
}) })
} }
}) })
if hasDefer {
n++
}
if n != 0 || hasInit { if n != 0 || hasInit {
if n > 255 { if n > 255 {
c.prog.BinWriter.Err = errors.New("too many global variables") c.prog.BinWriter.Err = errors.New("too many global variables")
@ -83,6 +90,11 @@ func (c *codegen) traverseGlobals() (int, bool) {
// encountered after will be recognized as globals. // encountered after will be recognized as globals.
c.scope = nil c.scope = nil
}) })
// store auxiliary variables after all others.
if hasDefer {
c.exceptionIndex = len(c.globals)
c.globals["<exception>"] = c.exceptionIndex
}
} }
return n, hasInit return n, hasInit
} }
@ -92,7 +104,7 @@ func (c *codegen) traverseGlobals() (int, bool) {
func countGlobals(f ast.Node) (i int) { func countGlobals(f ast.Node) (i int) {
ast.Inspect(f, func(node ast.Node) bool { ast.Inspect(f, func(node ast.Node) bool {
switch n := node.(type) { switch n := node.(type) {
// Skip all function declarations. // Skip all function declarations if we have already encountered `defer`.
case *ast.FuncDecl: case *ast.FuncDecl:
return false return false
// After skipping all funcDecls we are sure that each value spec // After skipping all funcDecls we are sure that each value spec

View file

@ -12,7 +12,6 @@ import (
"sort" "sort"
"strings" "strings"
"github.com/nspcc-dev/neo-go/pkg/core/interop/interopnames"
"github.com/nspcc-dev/neo-go/pkg/encoding/address" "github.com/nspcc-dev/neo-go/pkg/encoding/address"
"github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/io"
"github.com/nspcc-dev/neo-go/pkg/vm" "github.com/nspcc-dev/neo-go/pkg/vm"
@ -78,6 +77,9 @@ type codegen struct {
// packages contains packages in the order they were loaded. // packages contains packages in the order they were loaded.
packages []string packages []string
// exceptionIndex is the index of static slot where exception is stored.
exceptionIndex int
// documents contains paths to all files used by the program. // documents contains paths to all files used by the program.
documents []string documents []string
// docIndex maps file path to an index in documents array. // docIndex maps file path to an index in documents array.
@ -217,6 +219,11 @@ func getBaseOpcode(t varType) (opcode.Opcode, opcode.Opcode) {
// emitLoadVar loads specified variable to the evaluation stack. // emitLoadVar loads specified variable to the evaluation stack.
func (c *codegen) emitLoadVar(pkg string, name string) { func (c *codegen) emitLoadVar(pkg string, name string) {
t, i := c.getVarIndex(pkg, name) t, i := c.getVarIndex(pkg, name)
c.emitLoadByIndex(t, i)
}
// emitLoadByIndex loads specified variable type with index i.
func (c *codegen) emitLoadByIndex(t varType, i int) {
base, _ := getBaseOpcode(t) base, _ := getBaseOpcode(t)
if i < 7 { if i < 7 {
emit.Opcode(c.prog.BinWriter, base+opcode.Opcode(i)) emit.Opcode(c.prog.BinWriter, base+opcode.Opcode(i))
@ -232,6 +239,11 @@ func (c *codegen) emitStoreVar(pkg string, name string) {
return return
} }
t, i := c.getVarIndex(pkg, name) t, i := c.getVarIndex(pkg, name)
c.emitStoreByIndex(t, i)
}
// emitLoadByIndex stores top value in the specified variable type with index i.
func (c *codegen) emitStoreByIndex(t varType, i int) {
_, base := getBaseOpcode(t) _, base := getBaseOpcode(t)
if i < 7 { if i < 7 {
emit.Opcode(c.prog.BinWriter, base+opcode.Opcode(i)) emit.Opcode(c.prog.BinWriter, base+opcode.Opcode(i))
@ -553,6 +565,8 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
} }
} }
c.processDefers()
c.saveSequencePoint(n) c.saveSequencePoint(n)
emit.Opcode(c.prog.BinWriter, opcode.RET) emit.Opcode(c.prog.BinWriter, opcode.RET)
return nil return nil
@ -565,6 +579,9 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
lElse := c.newLabel() lElse := c.newLabel()
lElseEnd := c.newLabel() lElseEnd := c.newLabel()
if n.Init != nil {
ast.Walk(c, n.Init)
}
if n.Cond != nil { if n.Cond != nil {
c.emitBoolExpr(n.Cond, true, false, lElse) c.emitBoolExpr(n.Cond, true, false, lElse)
} }
@ -708,6 +725,7 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
numArgs = len(n.Args) numArgs = len(n.Args)
isBuiltin bool isBuiltin bool
isFunc bool isFunc bool
isLiteral bool
) )
switch fun := n.Fun.(type) { switch fun := n.Fun.(type) {
@ -746,6 +764,8 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
ast.Walk(c, n.Args[0]) ast.Walk(c, n.Args[0])
c.emitConvert(stackitem.BufferT) c.emitConvert(stackitem.BufferT)
return nil return nil
case *ast.FuncLit:
isLiteral = true
} }
c.saveSequencePoint(n) c.saveSequencePoint(n)
@ -798,6 +818,9 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
c.emitLoadVar("", name) c.emitLoadVar("", name)
emit.Opcode(c.prog.BinWriter, opcode.CALLA) emit.Opcode(c.prog.BinWriter, opcode.CALLA)
} }
case isLiteral:
ast.Walk(c, n.Fun)
emit.Opcode(c.prog.BinWriter, opcode.CALLA)
case isSyscall(f): case isSyscall(f):
c.convertSyscall(n, f.pkg.Name(), f.name) c.convertSyscall(n, f.pkg.Name(), f.name)
default: default:
@ -820,6 +843,20 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
return nil return nil
case *ast.DeferStmt:
catch := c.newLabel()
finally := c.newLabel()
param := make([]byte, 8)
binary.LittleEndian.PutUint16(param[0:], catch)
binary.LittleEndian.PutUint16(param[4:], finally)
emit.Instruction(c.prog.BinWriter, opcode.TRYL, param)
c.scope.deferStack = append(c.scope.deferStack, deferInfo{
catchLabel: catch,
finallyLabel: finally,
expr: n.Call,
})
return nil
case *ast.SelectorExpr: case *ast.SelectorExpr:
typ := c.typeOf(n.X) typ := c.typeOf(n.X)
if typ == nil { if typ == nil {
@ -1081,6 +1118,51 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
return c return c
} }
// processDefers emits code for `defer` statements.
// TRY-related opcodes handle exception as follows:
// 1. CATCH block is executed only if exception has occured.
// 2. FINALLY block is always executed, but after catch block.
// Go `defer` statements are a bit different:
// 1. `defer` is always executed irregardless of whether an exception has occured.
// 2. `recover` can or can not handle a possible exception.
// Thus we use the following approach:
// 1. Throwed exception is saved in a static field X, static fields Y and is set to true.
// 2. CATCH and FINALLY blocks are the same, and both contain the same CALLs.
// 3. In CATCH block we set Y to true and emit default return values if it is the last defer.
// 4. Execute FINALLY block only if Y is false.
func (c *codegen) processDefers() {
for i := len(c.scope.deferStack) - 1; i >= 0; i-- {
stmt := c.scope.deferStack[i]
after := c.newLabel()
emit.Jmp(c.prog.BinWriter, opcode.ENDTRYL, after)
c.setLabel(stmt.catchLabel)
c.emitStoreByIndex(varGlobal, c.exceptionIndex)
emit.Int(c.prog.BinWriter, 1)
c.emitStoreByIndex(varLocal, c.scope.finallyProcessedIndex)
ast.Walk(c, stmt.expr)
if i == 0 {
// After panic, default values must be returns, except for named returns,
// which we don't support here for now.
for i := len(c.scope.decl.Type.Results.List) - 1; i >= 0; i-- {
c.emitDefault(c.typeOf(c.scope.decl.Type.Results.List[i].Type))
}
}
emit.Jmp(c.prog.BinWriter, opcode.ENDTRYL, after)
c.setLabel(stmt.finallyLabel)
before := c.newLabel()
c.emitLoadByIndex(varLocal, c.scope.finallyProcessedIndex)
emit.Jmp(c.prog.BinWriter, opcode.JMPIFL, before)
ast.Walk(c, stmt.expr)
c.setLabel(before)
emit.Int(c.prog.BinWriter, 0)
c.emitStoreByIndex(varLocal, c.scope.finallyProcessedIndex)
emit.Opcode(c.prog.BinWriter, opcode.ENDFINALLY)
c.setLabel(after)
}
}
func (c *codegen) rangeLoadKey() { func (c *codegen) rangeLoadKey() {
emit.Int(c.prog.BinWriter, 2) emit.Int(c.prog.BinWriter, 2)
emit.Opcode(c.prog.BinWriter, opcode.PICK) // load keys emit.Opcode(c.prog.BinWriter, opcode.PICK) // load keys
@ -1428,17 +1510,11 @@ func (c *codegen) convertBuiltin(expr *ast.CallExpr) {
} }
} }
case "panic": case "panic":
arg := expr.Args[0] emit.Opcode(c.prog.BinWriter, opcode.THROW)
if isExprNil(arg) { case "recover":
emit.Opcode(c.prog.BinWriter, opcode.DROP) c.emitLoadByIndex(varGlobal, c.exceptionIndex)
emit.Opcode(c.prog.BinWriter, opcode.THROW) emit.Opcode(c.prog.BinWriter, opcode.PUSHNULL)
} else if isString(c.typeInfo.Types[arg].Type) { c.emitStoreByIndex(varGlobal, c.exceptionIndex)
ast.Walk(c, arg)
emit.Syscall(c.prog.BinWriter, interopnames.SystemRuntimeLog)
emit.Opcode(c.prog.BinWriter, opcode.THROW)
} else {
c.prog.Err = errors.New("panic should have string or nil argument")
}
case "ToInteger", "ToByteArray", "ToBool": case "ToInteger", "ToByteArray", "ToBool":
typ := stackitem.IntegerT typ := stackitem.IntegerT
switch name { switch name {
@ -1482,8 +1558,6 @@ func transformArgs(fun ast.Expr, args []ast.Expr) []ast.Expr {
} }
case *ast.Ident: case *ast.Ident:
switch f.Name { switch f.Name {
case "panic":
return args[1:]
case "make", "copy": case "make", "copy":
return nil return nil
} }
@ -1778,27 +1852,32 @@ func (c *codegen) writeJumps(b []byte) ([]byte, error) {
case opcode.JMP, opcode.JMPIFNOT, opcode.JMPIF, opcode.CALL, case opcode.JMP, opcode.JMPIFNOT, opcode.JMPIF, opcode.CALL,
opcode.JMPEQ, opcode.JMPNE, opcode.JMPEQ, opcode.JMPNE,
opcode.JMPGT, opcode.JMPGE, opcode.JMPLE, opcode.JMPLT: opcode.JMPGT, opcode.JMPGE, opcode.JMPLE, opcode.JMPLT:
case opcode.TRYL:
nextIP := ctx.NextIP()
catchArg := b[nextIP-8:]
_, err := c.replaceLabelWithOffset(ctx.IP(), catchArg)
if err != nil {
return nil, err
}
finallyArg := b[nextIP-4:]
_, err = c.replaceLabelWithOffset(ctx.IP(), finallyArg)
if err != nil {
return nil, err
}
case opcode.JMPL, opcode.JMPIFL, opcode.JMPIFNOTL, case opcode.JMPL, opcode.JMPIFL, opcode.JMPIFNOTL,
opcode.JMPEQL, opcode.JMPNEL, opcode.JMPEQL, opcode.JMPNEL,
opcode.JMPGTL, opcode.JMPGEL, opcode.JMPLEL, opcode.JMPLTL, opcode.JMPGTL, opcode.JMPGEL, opcode.JMPLEL, opcode.JMPLTL,
opcode.CALLL, opcode.PUSHA: opcode.CALLL, opcode.PUSHA, opcode.ENDTRYL:
// we can't use arg returned by ctx.Next() because it is copied // we can't use arg returned by ctx.Next() because it is copied
nextIP := ctx.NextIP() nextIP := ctx.NextIP()
arg := b[nextIP-4:] arg := b[nextIP-4:]
offset, err := c.replaceLabelWithOffset(ctx.IP(), arg)
index := binary.LittleEndian.Uint16(arg) if err != nil {
if int(index) > len(c.l) { return nil, err
return nil, fmt.Errorf("unexpected label number: %d (max %d)", index, len(c.l))
}
offset := c.l[index] - nextIP + 5
if offset > math.MaxInt32 || offset < math.MinInt32 {
return nil, fmt.Errorf("label offset is too big at the instruction %d: %d (max %d, min %d)",
nextIP-5, offset, math.MaxInt32, math.MinInt32)
} }
if op != opcode.PUSHA && math.MinInt8 <= offset && offset <= math.MaxInt8 { if op != opcode.PUSHA && math.MinInt8 <= offset && offset <= math.MaxInt8 {
offsets = append(offsets, ctx.IP()) offsets = append(offsets, ctx.IP())
} }
binary.LittleEndian.PutUint32(arg, uint32(offset))
} }
} }
// Correct function ip range. // Correct function ip range.
@ -1820,6 +1899,20 @@ func (c *codegen) writeJumps(b []byte) ([]byte, error) {
return shortenJumps(b, offsets), nil return shortenJumps(b, offsets), nil
} }
func (c *codegen) replaceLabelWithOffset(ip int, arg []byte) (int, error) {
index := binary.LittleEndian.Uint16(arg)
if int(index) > len(c.l) {
return 0, fmt.Errorf("unexpected label number: %d (max %d)", index, len(c.l))
}
offset := c.l[index] - ip
if offset > math.MaxInt32 || offset < math.MinInt32 {
return 0, fmt.Errorf("label offset is too big at the instruction %d: %d (max %d, min %d)",
ip, offset, math.MaxInt32, math.MinInt32)
}
binary.LittleEndian.PutUint32(arg, uint32(offset))
return offset, nil
}
// longToShortRemoveCount is a difference between short and long instruction sizes in bytes. // longToShortRemoveCount is a difference between short and long instruction sizes in bytes.
const longToShortRemoveCount = 3 const longToShortRemoveCount = 3
@ -1844,18 +1937,34 @@ func shortenJumps(b []byte, offsets []int) []byte {
switch op { switch op {
case opcode.JMP, opcode.JMPIFNOT, opcode.JMPIF, opcode.CALL, case opcode.JMP, opcode.JMPIFNOT, opcode.JMPIF, opcode.CALL,
opcode.JMPEQ, opcode.JMPNE, opcode.JMPEQ, opcode.JMPNE,
opcode.JMPGT, opcode.JMPGE, opcode.JMPLE, opcode.JMPLT: opcode.JMPGT, opcode.JMPGE, opcode.JMPLE, opcode.JMPLT, opcode.ENDTRY:
offset := int(int8(b[nextIP-1])) offset := int(int8(b[nextIP-1]))
offset += calcOffsetCorrection(ip, ip+offset, offsets) offset += calcOffsetCorrection(ip, ip+offset, offsets)
b[nextIP-1] = byte(offset) b[nextIP-1] = byte(offset)
case opcode.TRY:
catchOffset := int(int8(b[nextIP-2]))
catchOffset += calcOffsetCorrection(ip, ip+catchOffset, offsets)
b[nextIP-1] = byte(catchOffset)
finallyOffset := int(int8(b[nextIP-1]))
finallyOffset += calcOffsetCorrection(ip, ip+finallyOffset, offsets)
b[nextIP-1] = byte(finallyOffset)
case opcode.JMPL, opcode.JMPIFL, opcode.JMPIFNOTL, case opcode.JMPL, opcode.JMPIFL, opcode.JMPIFNOTL,
opcode.JMPEQL, opcode.JMPNEL, opcode.JMPEQL, opcode.JMPNEL,
opcode.JMPGTL, opcode.JMPGEL, opcode.JMPLEL, opcode.JMPLTL, opcode.JMPGTL, opcode.JMPGEL, opcode.JMPLEL, opcode.JMPLTL,
opcode.CALLL, opcode.PUSHA: opcode.CALLL, opcode.PUSHA, opcode.ENDTRYL:
arg := b[nextIP-4:] arg := b[nextIP-4:]
offset := int(int32(binary.LittleEndian.Uint32(arg))) offset := int(int32(binary.LittleEndian.Uint32(arg)))
offset += calcOffsetCorrection(ip, ip+offset, offsets) offset += calcOffsetCorrection(ip, ip+offset, offsets)
binary.LittleEndian.PutUint32(arg, uint32(offset)) binary.LittleEndian.PutUint32(arg, uint32(offset))
case opcode.TRYL:
arg := b[nextIP-8:]
catchOffset := int(int32(binary.LittleEndian.Uint32(arg)))
catchOffset += calcOffsetCorrection(ip, ip+catchOffset, offsets)
binary.LittleEndian.PutUint32(arg, uint32(catchOffset))
arg = b[nextIP-4:]
finallyOffset := int(int32(binary.LittleEndian.Uint32(arg)))
finallyOffset += calcOffsetCorrection(ip, ip+finallyOffset, offsets)
binary.LittleEndian.PutUint32(arg, uint32(finallyOffset))
} }
} }
@ -1939,6 +2048,8 @@ func toShortForm(op opcode.Opcode) opcode.Opcode {
return opcode.JMPLT return opcode.JMPLT
case opcode.CALLL: case opcode.CALLL:
return opcode.CALL return opcode.CALL
case opcode.ENDTRYL:
return opcode.ENDTRY
default: default:
panic(fmt.Errorf("invalid opcode: %s", op)) panic(fmt.Errorf("invalid opcode: %s", op))
} }

139
pkg/compiler/defer_test.go Normal file
View file

@ -0,0 +1,139 @@
package compiler_test
import (
"math/big"
"testing"
)
func TestDefer(t *testing.T) {
t.Run("Simple", func(t *testing.T) {
src := `package main
var a int
func Main() int {
return h() + a
}
func h() int {
defer f()
return 1
}
func f() { a += 2 }`
eval(t, src, big.NewInt(3))
})
t.Run("ValueUnchanged", func(t *testing.T) {
src := `package main
var a int
func Main() int {
defer f()
a = 3
return a
}
func f() { a += 2 }`
eval(t, src, big.NewInt(3))
})
t.Run("Function", func(t *testing.T) {
src := `package main
var a int
func Main() int {
return h() + a
}
func h() int {
defer f()
a = 3
return g()
}
func g() int {
a++
return a
}
func f() { a += 2 }`
eval(t, src, big.NewInt(10))
})
t.Run("MultipleDefers", func(t *testing.T) {
src := `package main
var a int
func Main() int {
return h() + a
}
func h() int {
defer f()
defer g()
a = 3
return a
}
func g() { a *= 2 }
func f() { a += 2 }`
eval(t, src, big.NewInt(11))
})
t.Run("FunctionLiteral", func(t *testing.T) {
src := `package main
var a int
func Main() int {
return h() + a
}
func h() int {
defer func() {
a = 10
}()
a = 3
return a
}`
eval(t, src, big.NewInt(13))
})
}
func TestRecover(t *testing.T) {
t.Run("Panic", func(t *testing.T) {
src := `package foo
var a int
func Main() int {
return h() + a
}
func h() int {
defer func() {
if r := recover(); r != nil {
a = 3
} else {
a = 4
}
}()
a = 1
panic("msg")
return a
}`
eval(t, src, big.NewInt(3))
})
t.Run("NoPanic", func(t *testing.T) {
src := `package foo
var a int
func Main() int {
return h() + a
}
func h() int {
defer func() {
if r := recover(); r != nil {
a = 3
} else {
a = 4
}
}()
a = 1
return a
}`
eval(t, src, big.NewInt(5))
})
t.Run("PanicInDefer", func(t *testing.T) {
src := `package foo
var a int
func Main() int {
return h() + a
}
func h() int {
defer func() { a += 2; _ = recover() }()
defer func() { a *= 3; _ = recover(); panic("again") }()
a = 1
panic("msg")
return a
}`
eval(t, src, big.NewInt(5))
})
}

View file

@ -30,6 +30,12 @@ type funcScope struct {
// Variables together with it's type in neo-vm. // Variables together with it's type in neo-vm.
variables []string variables []string
// deferStack is a stack containing encountered `defer` statements.
deferStack []deferInfo
// finallyProcessed is a index of static slot with boolean flag determining
// if `defer` statement was already processed.
finallyProcessedIndex int
// Local variables // Local variables
vars varScope vars varScope
@ -45,6 +51,12 @@ type funcScope struct {
i int i int
} }
type deferInfo struct {
catchLabel uint16
finallyLabel uint16
expr *ast.CallExpr
}
func (c *codegen) newFuncScope(decl *ast.FuncDecl, label uint16) *funcScope { func (c *codegen) newFuncScope(decl *ast.FuncDecl, label uint16) *funcScope {
var name string var name string
if decl.Name != nil { if decl.Name != nil {
@ -115,6 +127,7 @@ func (c *funcScope) analyzeVoidCalls(node ast.Node) bool {
func (c *funcScope) countLocals() int { func (c *funcScope) countLocals() int {
size := 0 size := 0
hasDefer := false
ast.Inspect(c.decl, func(n ast.Node) bool { ast.Inspect(c.decl, func(n ast.Node) bool {
switch n := n.(type) { switch n := n.(type) {
case *ast.FuncType: case *ast.FuncType:
@ -124,8 +137,11 @@ func (c *funcScope) countLocals() int {
} }
case *ast.AssignStmt: case *ast.AssignStmt:
if n.Tok == token.DEFINE { if n.Tok == token.DEFINE {
size += len(n.Rhs) size += len(n.Lhs)
} }
case *ast.DeferStmt:
hasDefer = true
return false
case *ast.ReturnStmt, *ast.IfStmt: case *ast.ReturnStmt, *ast.IfStmt:
size++ size++
// This handles the inline GenDecl like "var x = 2" // This handles the inline GenDecl like "var x = 2"
@ -143,6 +159,10 @@ func (c *funcScope) countLocals() int {
} }
return true return true
}) })
if hasDefer {
c.finallyProcessedIndex = size
size++
}
return size return size
} }

View file

@ -38,6 +38,18 @@ func TestMultiDeclaration(t *testing.T) {
eval(t, src, big.NewInt(6)) eval(t, src, big.NewInt(6))
} }
func TestCountLocal(t *testing.T) {
src := `package foo
func Main() int {
a, b, c, d := f()
return a + b + c + d
}
func f() (int, int, int, int) {
return 1, 2, 3, 4
}`
eval(t, src, big.NewInt(10))
}
func TestMultiDeclarationLocal(t *testing.T) { func TestMultiDeclarationLocal(t *testing.T) {
src := `package foo src := `package foo
func Main() int { func Main() int {

View file

@ -1,6 +1,7 @@
package compiler_test package compiler_test
import ( import (
"fmt"
"math/big" "math/big"
"testing" "testing"
) )
@ -91,3 +92,32 @@ func TestNestedIF(t *testing.T) {
` `
eval(t, src, big.NewInt(0)) eval(t, src, big.NewInt(0))
} }
func TestInitIF(t *testing.T) {
t.Run("Simple", func(t *testing.T) {
src := `package foo
func Main() int {
if a := 42; true {
return a
}
return 0
}`
eval(t, src, big.NewInt(42))
})
t.Run("Shadow", func(t *testing.T) {
srcTmpl := `package foo
func Main() int {
a := 11
if a := 42; %v {
return a
}
return a
}`
t.Run("True", func(t *testing.T) {
eval(t, fmt.Sprintf(srcTmpl, true), big.NewInt(42))
})
t.Run("False", func(t *testing.T) {
eval(t, fmt.Sprintf(srcTmpl, false), big.NewInt(11))
})
})
}

View file

@ -13,3 +13,16 @@ func TestFuncLiteral(t *testing.T) {
}` }`
eval(t, src, big.NewInt(5)) eval(t, src, big.NewInt(5))
} }
func TestCallInPlace(t *testing.T) {
src := `package foo
var a int = 1
func Main() int {
func() {
a += 10
}()
a += 100
return a
}`
eval(t, src, big.NewInt(111))
}

View file

@ -1,13 +1,10 @@
package compiler_test package compiler_test
import ( import (
"errors"
"fmt" "fmt"
"math/big" "math/big"
"testing" "testing"
"github.com/nspcc-dev/neo-go/pkg/core/interop/interopnames"
"github.com/nspcc-dev/neo-go/pkg/vm"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -18,26 +15,19 @@ func TestPanic(t *testing.T) {
}) })
t.Run("panic with message", func(t *testing.T) { t.Run("panic with message", func(t *testing.T) {
var logs []string
src := getPanicSource(true, `"execution fault"`) src := getPanicSource(true, `"execution fault"`)
v := vmAndCompile(t, src) v := vmAndCompile(t, src)
v.SyscallHandler = getLogHandler(&logs)
require.Error(t, v.Run()) require.Error(t, v.Run())
require.True(t, v.HasFailed()) require.True(t, v.HasFailed())
require.Equal(t, 1, len(logs))
require.Equal(t, "execution fault", logs[0])
}) })
t.Run("panic with nil", func(t *testing.T) { t.Run("panic with nil", func(t *testing.T) {
var logs []string
src := getPanicSource(true, `nil`) src := getPanicSource(true, `nil`)
v := vmAndCompile(t, src) v := vmAndCompile(t, src)
v.SyscallHandler = getLogHandler(&logs)
require.Error(t, v.Run()) require.Error(t, v.Run())
require.True(t, v.HasFailed()) require.True(t, v.HasFailed())
require.Equal(t, 0, len(logs))
}) })
} }
@ -54,16 +44,3 @@ func getPanicSource(need bool, message string) string {
} }
`, need, message) `, need, message)
} }
func getLogHandler(logs *[]string) vm.SyscallHandler {
logID := interopnames.ToID([]byte(interopnames.SystemRuntimeLog))
return func(v *vm.VM, id uint32) error {
if id != logID {
return errors.New("syscall not found")
}
msg := v.Estack().Pop().String()
*logs = append(*logs, msg)
return nil
}
}

View file

@ -160,5 +160,5 @@ func AppCallWithOperationAndArgs(w *io.BinWriter, scriptHash util.Uint160, opera
} }
func isInstructionJmp(op opcode.Opcode) bool { func isInstructionJmp(op opcode.Opcode) bool {
return opcode.JMP <= op && op <= opcode.CALLL return opcode.JMP <= op && op <= opcode.CALLL || op == opcode.ENDTRYL
} }

View file

@ -1476,6 +1476,7 @@ func (v *VM) Call(ctx *Context, offset int) {
newCtx.CheckReturn = false newCtx.CheckReturn = false
newCtx.local = nil newCtx.local = nil
newCtx.arguments = nil newCtx.arguments = nil
newCtx.tryStack = NewStack("exception")
v.istack.PushVal(newCtx) v.istack.PushVal(newCtx)
v.Jump(newCtx, offset) v.Jump(newCtx, offset)
} }
@ -1517,7 +1518,7 @@ func (v *VM) handleException() {
pop := 0 pop := 0
ictx := v.istack.Peek(0).Value().(*Context) ictx := v.istack.Peek(0).Value().(*Context)
for ictx != nil { for ictx != nil {
e := ictx.tryStack.Peek(pop) e := ictx.tryStack.Peek(0)
for e != nil { for e != nil {
ectx := e.Value().(*exceptionHandlingContext) ectx := e.Value().(*exceptionHandlingContext)
if ectx.State == eFinally || (ectx.State == eCatch && !ectx.HasFinally()) { if ectx.State == eFinally || (ectx.State == eCatch && !ectx.HasFinally()) {

View file

@ -1336,6 +1336,12 @@ func TestTRY(t *testing.T) {
checkVMFailed(t, vm) checkVMFailed(t, vm)
}) })
}) })
t.Run("ThrowInCall", func(t *testing.T) {
catchP := []byte{byte(opcode.CALL), 2, byte(opcode.PUSH1), byte(opcode.ADD), byte(opcode.THROW), byte(opcode.RET)}
inner := getTRYProgram(throw, catchP, []byte{byte(opcode.PUSH2)})
// add 5 to the exception, mul to the result of inner finally (2)
getTRYTestFunc(47, inner, append(add5, byte(opcode.MUL)), add9)(t)
})
} }
func TestMEMCPY(t *testing.T) { func TestMEMCPY(t *testing.T) {