diff --git a/pkg/compiler/binary_expr_test.go b/pkg/compiler/binary_expr_test.go index bf3f94c62..22bafd8a7 100644 --- a/pkg/compiler/binary_expr_test.go +++ b/pkg/compiler/binary_expr_test.go @@ -1,6 +1,7 @@ package compiler_test import ( + "fmt" "math/big" "testing" ) @@ -262,3 +263,101 @@ var binaryExprTestCases = []testCase{ func TestBinaryExprs(t *testing.T) { runTestCases(t, binaryExprTestCases) } + +func getBoolExprTestFunc(val bool, cond string) func(t *testing.T) { + srcTmpl := `package foo + var s = "str" + var v = 9 + var cond = %s + func Main() int { + if %s { + return 42 + } %s + return 17 + %s + }` + res := big.NewInt(42) + if !val { + res.SetInt64(17) + } + return func(t *testing.T) { + t.Run("AsExpression", func(t *testing.T) { + src := fmt.Sprintf(srcTmpl, cond, "cond", "", "") + eval(t, src, res) + }) + t.Run("InCondition", func(t *testing.T) { + src := fmt.Sprintf(srcTmpl, "true", cond, "", "") + eval(t, src, res) + }) + t.Run("InConditionWithElse", func(t *testing.T) { + src := fmt.Sprintf(srcTmpl, "true", cond, " else {", "}") + eval(t, src, res) + }) + } +} + +// TestBooleanExprs enumerates a lot of possible combinations of boolean expressions +// and tests if the result matches to that of Go. +func TestBooleanExprs(t *testing.T) { + trueExpr := []string{"true", "v < 10", "v <= 9", "v > 8", "v >= 9", "v == 9", "v != 8", `s == "str"`} + falseExpr := []string{"false", "v > 9", "v >= 10", "v < 9", "v <= 8", "v == 8", "v != 9", `s == "a"`} + t.Run("Single", func(t *testing.T) { + for _, s := range trueExpr { + t.Run(s, getBoolExprTestFunc(true, s)) + } + for _, s := range falseExpr { + t.Run(s, getBoolExprTestFunc(false, s)) + } + }) + + type arg struct { + val bool + s string + } + t.Run("Combine", func(t *testing.T) { + var double []arg + for _, e := range trueExpr { + double = append(double, arg{true, e + " || false"}) + double = append(double, arg{true, e + " && true"}) + } + for _, e := range falseExpr { + double = append(double, arg{false, e + " && true"}) + double = append(double, arg{false, e + " || false"}) + } + for i := range double { + t.Run(double[i].s, getBoolExprTestFunc(double[i].val, double[i].s)) + } + + var triple []arg + for _, a1 := range double { + for _, a2 := range double { + triple = append(triple, arg{a1.val || a2.val, fmt.Sprintf("(%s) || (%s)", a1.s, a2.s)}) + triple = append(triple, arg{a1.val && a2.val, fmt.Sprintf("(%s) && (%s)", a1.s, a2.s)}) + } + } + for i := range triple { + t.Run(triple[i].s, getBoolExprTestFunc(triple[i].val, triple[i].s)) + } + }) + return +} + +func TestShortCircuit(t *testing.T) { + srcTmpl := `package foo + var a = 1 + func inc() bool { a += 1; return %s } + func Main() int { + if %s { + return 41 + a + } + return 16 + a + }` + t.Run("||", func(t *testing.T) { + src := fmt.Sprintf(srcTmpl, "true", "a == 1 || inc()") + eval(t, src, big.NewInt(42)) + }) + t.Run("&&", func(t *testing.T) { + src := fmt.Sprintf(srcTmpl, "false", "a == 2 && inc()") + eval(t, src, big.NewInt(17)) + }) +} diff --git a/pkg/compiler/codegen.go b/pkg/compiler/codegen.go index fcae45fbf..a9cac98d8 100644 --- a/pkg/compiler/codegen.go +++ b/pkg/compiler/codegen.go @@ -1087,6 +1087,8 @@ func (c *codegen) emitJumpOnCondition(cond bool, jmpLabel uint16) { } } +// emitBoolExpr emits boolean expression. If needJump is true and expression evaluates to `cond`, +// jump to jmpLabel is performed and no item is left on stack. func (c *codegen) emitBoolExpr(n ast.Expr, needJump bool, cond bool, jmpLabel uint16) { if be, ok := n.(*ast.BinaryExpr); ok { c.emitBinaryExpr(be, needJump, cond, jmpLabel) @@ -1098,6 +1100,8 @@ func (c *codegen) emitBoolExpr(n ast.Expr, needJump bool, cond bool, jmpLabel ui } } +// emitBinaryExpr emits binary expression. If needJump is true and expression evaluates to `cond`, +// jump to jmpLabel is performed and no item is left on stack. func (c *codegen) emitBinaryExpr(n *ast.BinaryExpr, needJump bool, cond bool, jmpLabel uint16) { // The AST package will try to resolve all basic literals for us. // If the typeinfo.Value is not nil we know that the expr is resolved @@ -1150,10 +1154,21 @@ func (c *codegen) emitBinaryExpr(n *ast.BinaryExpr, needJump bool, cond bool, jm default: ast.Walk(c, n.X) ast.Walk(c, n.Y) - c.emitToken(n.Op, c.typeOf(n.X)) - if needJump { - c.emitJumpOnCondition(cond, jmpLabel) + typ := c.typeOf(n.X) + if !needJump { + c.emitToken(n.Op, typ) + return } + op, ok := getJumpForToken(n.Op, typ) + if !ok { + c.emitToken(n.Op, typ) + c.emitJumpOnCondition(cond, jmpLabel) + return + } + if !cond { + op = negateJmp(op) + } + emit.Jmp(c.prog.BinWriter, op, jmpLabel) } } @@ -1214,6 +1229,29 @@ func (c *codegen) getLabelOffset(typ labelOffsetType, name string) uint16 { return c.labels[labelWithType{name: name, typ: typ}] } +// For `&&` and `||` it return an opcode which jumps only if result is known: +// false && .. == false, true || .. = true +func getJumpForToken(tok token.Token, typ types.Type) (opcode.Opcode, bool) { + switch tok { + case token.GTR: + return opcode.JMPGTL, true + case token.GEQ: + return opcode.JMPGEL, true + case token.LSS: + return opcode.JMPLTL, true + case token.LEQ: + return opcode.JMPLEL, true + case token.EQL, token.NEQ: + if isNumber(typ) { + if tok == token.EQL { + return opcode.JMPEQL, true + } + return opcode.JMPNEL, true + } + } + return 0, false +} + // getByteArray returns byte array value from constant expr. // Only literals are supported. func (c *codegen) getByteArray(expr ast.Expr) []byte { @@ -1766,6 +1804,29 @@ func calcOffsetCorrection(ip, target int, offsets []int) int { return cnt } +func negateJmp(op opcode.Opcode) opcode.Opcode { + switch op { + case opcode.JMPIFL: + return opcode.JMPIFNOTL + case opcode.JMPIFNOTL: + return opcode.JMPIFL + case opcode.JMPEQL: + return opcode.JMPNEL + case opcode.JMPNEL: + return opcode.JMPEQL + case opcode.JMPGTL: + return opcode.JMPLEL + case opcode.JMPGEL: + return opcode.JMPLTL + case opcode.JMPLEL: + return opcode.JMPGTL + case opcode.JMPLTL: + return opcode.JMPGEL + default: + panic(fmt.Errorf("invalid opcode in negateJmp: %s", op)) + } +} + func toShortForm(op opcode.Opcode) opcode.Opcode { switch op { case opcode.JMPL: