compiler: make use of extended JMP* opcodes

This commit is contained in:
Evgenii Stratonikov 2020-08-23 18:00:23 +03:00
parent 51f3baf68e
commit 9dc3edf351
2 changed files with 163 additions and 3 deletions

View file

@ -1,6 +1,7 @@
package compiler_test package compiler_test
import ( import (
"fmt"
"math/big" "math/big"
"testing" "testing"
) )
@ -262,3 +263,101 @@ var binaryExprTestCases = []testCase{
func TestBinaryExprs(t *testing.T) { func TestBinaryExprs(t *testing.T) {
runTestCases(t, binaryExprTestCases) runTestCases(t, binaryExprTestCases)
} }
func getBoolExprTestFunc(val bool, cond string) func(t *testing.T) {
srcTmpl := `package foo
var s = "str"
var v = 9
var cond = %s
func Main() int {
if %s {
return 42
} %s
return 17
%s
}`
res := big.NewInt(42)
if !val {
res.SetInt64(17)
}
return func(t *testing.T) {
t.Run("AsExpression", func(t *testing.T) {
src := fmt.Sprintf(srcTmpl, cond, "cond", "", "")
eval(t, src, res)
})
t.Run("InCondition", func(t *testing.T) {
src := fmt.Sprintf(srcTmpl, "true", cond, "", "")
eval(t, src, res)
})
t.Run("InConditionWithElse", func(t *testing.T) {
src := fmt.Sprintf(srcTmpl, "true", cond, " else {", "}")
eval(t, src, res)
})
}
}
// TestBooleanExprs enumerates a lot of possible combinations of boolean expressions
// and tests if the result matches to that of Go.
func TestBooleanExprs(t *testing.T) {
trueExpr := []string{"true", "v < 10", "v <= 9", "v > 8", "v >= 9", "v == 9", "v != 8", `s == "str"`}
falseExpr := []string{"false", "v > 9", "v >= 10", "v < 9", "v <= 8", "v == 8", "v != 9", `s == "a"`}
t.Run("Single", func(t *testing.T) {
for _, s := range trueExpr {
t.Run(s, getBoolExprTestFunc(true, s))
}
for _, s := range falseExpr {
t.Run(s, getBoolExprTestFunc(false, s))
}
})
type arg struct {
val bool
s string
}
t.Run("Combine", func(t *testing.T) {
var double []arg
for _, e := range trueExpr {
double = append(double, arg{true, e + " || false"})
double = append(double, arg{true, e + " && true"})
}
for _, e := range falseExpr {
double = append(double, arg{false, e + " && true"})
double = append(double, arg{false, e + " || false"})
}
for i := range double {
t.Run(double[i].s, getBoolExprTestFunc(double[i].val, double[i].s))
}
var triple []arg
for _, a1 := range double {
for _, a2 := range double {
triple = append(triple, arg{a1.val || a2.val, fmt.Sprintf("(%s) || (%s)", a1.s, a2.s)})
triple = append(triple, arg{a1.val && a2.val, fmt.Sprintf("(%s) && (%s)", a1.s, a2.s)})
}
}
for i := range triple {
t.Run(triple[i].s, getBoolExprTestFunc(triple[i].val, triple[i].s))
}
})
return
}
func TestShortCircuit(t *testing.T) {
srcTmpl := `package foo
var a = 1
func inc() bool { a += 1; return %s }
func Main() int {
if %s {
return 41 + a
}
return 16 + a
}`
t.Run("||", func(t *testing.T) {
src := fmt.Sprintf(srcTmpl, "true", "a == 1 || inc()")
eval(t, src, big.NewInt(42))
})
t.Run("&&", func(t *testing.T) {
src := fmt.Sprintf(srcTmpl, "false", "a == 2 && inc()")
eval(t, src, big.NewInt(17))
})
}

View file

@ -1087,6 +1087,8 @@ func (c *codegen) emitJumpOnCondition(cond bool, jmpLabel uint16) {
} }
} }
// emitBoolExpr emits boolean expression. If needJump is true and expression evaluates to `cond`,
// jump to jmpLabel is performed and no item is left on stack.
func (c *codegen) emitBoolExpr(n ast.Expr, needJump bool, cond bool, jmpLabel uint16) { func (c *codegen) emitBoolExpr(n ast.Expr, needJump bool, cond bool, jmpLabel uint16) {
if be, ok := n.(*ast.BinaryExpr); ok { if be, ok := n.(*ast.BinaryExpr); ok {
c.emitBinaryExpr(be, needJump, cond, jmpLabel) c.emitBinaryExpr(be, needJump, cond, jmpLabel)
@ -1098,6 +1100,8 @@ func (c *codegen) emitBoolExpr(n ast.Expr, needJump bool, cond bool, jmpLabel ui
} }
} }
// emitBinaryExpr emits binary expression. If needJump is true and expression evaluates to `cond`,
// jump to jmpLabel is performed and no item is left on stack.
func (c *codegen) emitBinaryExpr(n *ast.BinaryExpr, needJump bool, cond bool, jmpLabel uint16) { func (c *codegen) emitBinaryExpr(n *ast.BinaryExpr, needJump bool, cond bool, jmpLabel uint16) {
// The AST package will try to resolve all basic literals for us. // The AST package will try to resolve all basic literals for us.
// If the typeinfo.Value is not nil we know that the expr is resolved // If the typeinfo.Value is not nil we know that the expr is resolved
@ -1150,10 +1154,21 @@ func (c *codegen) emitBinaryExpr(n *ast.BinaryExpr, needJump bool, cond bool, jm
default: default:
ast.Walk(c, n.X) ast.Walk(c, n.X)
ast.Walk(c, n.Y) ast.Walk(c, n.Y)
c.emitToken(n.Op, c.typeOf(n.X)) typ := c.typeOf(n.X)
if needJump { if !needJump {
c.emitJumpOnCondition(cond, jmpLabel) c.emitToken(n.Op, typ)
return
} }
op, ok := getJumpForToken(n.Op, typ)
if !ok {
c.emitToken(n.Op, typ)
c.emitJumpOnCondition(cond, jmpLabel)
return
}
if !cond {
op = negateJmp(op)
}
emit.Jmp(c.prog.BinWriter, op, jmpLabel)
} }
} }
@ -1214,6 +1229,29 @@ func (c *codegen) getLabelOffset(typ labelOffsetType, name string) uint16 {
return c.labels[labelWithType{name: name, typ: typ}] return c.labels[labelWithType{name: name, typ: typ}]
} }
// For `&&` and `||` it return an opcode which jumps only if result is known:
// false && .. == false, true || .. = true
func getJumpForToken(tok token.Token, typ types.Type) (opcode.Opcode, bool) {
switch tok {
case token.GTR:
return opcode.JMPGTL, true
case token.GEQ:
return opcode.JMPGEL, true
case token.LSS:
return opcode.JMPLTL, true
case token.LEQ:
return opcode.JMPLEL, true
case token.EQL, token.NEQ:
if isNumber(typ) {
if tok == token.EQL {
return opcode.JMPEQL, true
}
return opcode.JMPNEL, true
}
}
return 0, false
}
// getByteArray returns byte array value from constant expr. // getByteArray returns byte array value from constant expr.
// Only literals are supported. // Only literals are supported.
func (c *codegen) getByteArray(expr ast.Expr) []byte { func (c *codegen) getByteArray(expr ast.Expr) []byte {
@ -1766,6 +1804,29 @@ func calcOffsetCorrection(ip, target int, offsets []int) int {
return cnt return cnt
} }
func negateJmp(op opcode.Opcode) opcode.Opcode {
switch op {
case opcode.JMPIFL:
return opcode.JMPIFNOTL
case opcode.JMPIFNOTL:
return opcode.JMPIFL
case opcode.JMPEQL:
return opcode.JMPNEL
case opcode.JMPNEL:
return opcode.JMPEQL
case opcode.JMPGTL:
return opcode.JMPLEL
case opcode.JMPGEL:
return opcode.JMPLTL
case opcode.JMPLEL:
return opcode.JMPGTL
case opcode.JMPLTL:
return opcode.JMPGEL
default:
panic(fmt.Errorf("invalid opcode in negateJmp: %s", op))
}
}
func toShortForm(op opcode.Opcode) opcode.Opcode { func toShortForm(op opcode.Opcode) opcode.Opcode {
switch op { switch op {
case opcode.JMPL: case opcode.JMPL: