Merge pull request #1351 from nspcc-dev/compiler/jmps

Make use of extended JMP* opcodes
This commit is contained in:
Roman Khimov 2020-08-24 19:22:24 +03:00 committed by GitHub
commit d8badd9a8d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 383 additions and 167 deletions

View file

@ -108,6 +108,16 @@ var assignTestCases = []testCase{
`,
big.NewInt(15),
},
{
"add assign for string",
`package foo
func Main() string {
s := "Hello, "
s += "world!"
return s
}`,
[]byte("Hello, world!"),
},
{
"decl assign",
`

View file

@ -1,6 +1,7 @@
package compiler_test
import (
"fmt"
"math/big"
"testing"
)
@ -262,3 +263,101 @@ var binaryExprTestCases = []testCase{
func TestBinaryExprs(t *testing.T) {
runTestCases(t, binaryExprTestCases)
}
func getBoolExprTestFunc(val bool, cond string) func(t *testing.T) {
srcTmpl := `package foo
var s = "str"
var v = 9
var cond = %s
func Main() int {
if %s {
return 42
} %s
return 17
%s
}`
res := big.NewInt(42)
if !val {
res.SetInt64(17)
}
return func(t *testing.T) {
t.Run("AsExpression", func(t *testing.T) {
src := fmt.Sprintf(srcTmpl, cond, "cond", "", "")
eval(t, src, res)
})
t.Run("InCondition", func(t *testing.T) {
src := fmt.Sprintf(srcTmpl, "true", cond, "", "")
eval(t, src, res)
})
t.Run("InConditionWithElse", func(t *testing.T) {
src := fmt.Sprintf(srcTmpl, "true", cond, " else {", "}")
eval(t, src, res)
})
}
}
// TestBooleanExprs enumerates a lot of possible combinations of boolean expressions
// and tests if the result matches to that of Go.
func TestBooleanExprs(t *testing.T) {
trueExpr := []string{"true", "v < 10", "v <= 9", "v > 8", "v >= 9", "v == 9", "v != 8", `s == "str"`}
falseExpr := []string{"false", "v > 9", "v >= 10", "v < 9", "v <= 8", "v == 8", "v != 9", `s == "a"`}
t.Run("Single", func(t *testing.T) {
for _, s := range trueExpr {
t.Run(s, getBoolExprTestFunc(true, s))
}
for _, s := range falseExpr {
t.Run(s, getBoolExprTestFunc(false, s))
}
})
type arg struct {
val bool
s string
}
t.Run("Combine", func(t *testing.T) {
var double []arg
for _, e := range trueExpr {
double = append(double, arg{true, e + " || false"})
double = append(double, arg{true, e + " && true"})
}
for _, e := range falseExpr {
double = append(double, arg{false, e + " && true"})
double = append(double, arg{false, e + " || false"})
}
for i := range double {
t.Run(double[i].s, getBoolExprTestFunc(double[i].val, double[i].s))
}
var triple []arg
for _, a1 := range double {
for _, a2 := range double {
triple = append(triple, arg{a1.val || a2.val, fmt.Sprintf("(%s) || (%s)", a1.s, a2.s)})
triple = append(triple, arg{a1.val && a2.val, fmt.Sprintf("(%s) && (%s)", a1.s, a2.s)})
}
}
for i := range triple {
t.Run(triple[i].s, getBoolExprTestFunc(triple[i].val, triple[i].s))
}
})
return
}
func TestShortCircuit(t *testing.T) {
srcTmpl := `package foo
var a = 1
func inc() bool { a += 1; return %s }
func Main() int {
if %s {
return 41 + a
}
return 16 + a
}`
t.Run("||", func(t *testing.T) {
src := fmt.Sprintf(srcTmpl, "true", "a == 1 || inc()")
eval(t, src, big.NewInt(42))
})
t.Run("&&", func(t *testing.T) {
src := fmt.Sprintf(srcTmpl, "false", "a == 2 && inc()")
eval(t, src, big.NewInt(17))
})
}

View file

@ -446,7 +446,7 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
// RHS can contain exactly one expression, thus there is no need to iterate.
ast.Walk(c, n.Lhs[0])
ast.Walk(c, n.Rhs[0])
c.convertToken(n.Tok)
c.emitToken(n.Tok, c.typeOf(n.Rhs[0]))
}
for i := 0; i < len(n.Lhs); i++ {
switch t := n.Lhs[i].(type) {
@ -561,8 +561,7 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
lElseEnd := c.newLabel()
if n.Cond != nil {
ast.Walk(c, n.Cond)
emit.Jmp(c.prog.BinWriter, opcode.JMPIFNOTL, lElse)
c.emitBoolExpr(n.Cond, true, false, lElse)
}
c.setLabel(lIf)
@ -581,7 +580,7 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
case *ast.SwitchStmt:
ast.Walk(c, n.Tag)
eqOpcode := c.getEqualityOpcode(n.Tag)
eqOpcode, _ := convertToken(token.EQL, c.typeOf(n.Tag))
switchEnd, label := c.generateLabel(labelEnd)
lastSwitch := c.currentSwitch
@ -689,86 +688,8 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
return nil
case *ast.BinaryExpr:
switch n.Op {
case token.LAND:
end := c.newLabel()
ast.Walk(c, n.X)
emit.Instruction(c.prog.BinWriter, opcode.JMPIF, []byte{2 + 1 + 5})
emit.Opcode(c.prog.BinWriter, opcode.PUSHF)
emit.Jmp(c.prog.BinWriter, opcode.JMPL, end)
ast.Walk(c, n.Y)
c.setLabel(end)
return nil
case token.LOR:
end := c.newLabel()
ast.Walk(c, n.X)
emit.Instruction(c.prog.BinWriter, opcode.JMPIFNOT, []byte{2 + 1 + 5})
emit.Opcode(c.prog.BinWriter, opcode.PUSHT)
emit.Jmp(c.prog.BinWriter, opcode.JMPL, end)
ast.Walk(c, n.Y)
c.setLabel(end)
return nil
default:
// The AST package will try to resolve all basic literals for us.
// If the typeinfo.Value is not nil we know that the expr is resolved
// and needs no further action. e.g. x := 2 + 2 + 2 will be resolved to 6.
// NOTE: Constants will also be automatically resolved be the AST parser.
// example:
// const x = 10
// x + 2 will results into 12
tinfo := c.typeAndValueOf(n)
if tinfo.Value != nil {
c.emitLoadConst(tinfo)
return nil
}
var checkForNull bool
if isExprNil(n.X) {
checkForNull = true
} else {
ast.Walk(c, n.X)
}
if isExprNil(n.Y) {
checkForNull = true
} else {
ast.Walk(c, n.Y)
}
if checkForNull {
emit.Opcode(c.prog.BinWriter, opcode.ISNULL)
if n.Op == token.NEQ {
emit.Opcode(c.prog.BinWriter, opcode.NOT)
}
return nil
}
switch {
case n.Op == token.ADD:
// VM has separate opcodes for number and string concatenation
if isString(tinfo.Type) {
emit.Opcode(c.prog.BinWriter, opcode.CAT)
} else {
emit.Opcode(c.prog.BinWriter, opcode.ADD)
}
case n.Op == token.EQL:
// VM has separate opcodes for number and string equality
op := c.getEqualityOpcode(n.X)
emit.Opcode(c.prog.BinWriter, op)
case n.Op == token.NEQ:
// VM has separate opcodes for number and string equality
if isString(c.typeOf(n.X)) {
emit.Opcode(c.prog.BinWriter, opcode.NOTEQUAL)
} else {
emit.Opcode(c.prog.BinWriter, opcode.NUMNOTEQUAL)
}
default:
c.convertToken(n.Op)
}
return nil
}
c.emitBinaryExpr(n, false, false, 0)
return nil
case *ast.CallExpr:
var (
@ -935,7 +856,7 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
case *ast.IncDecStmt:
ast.Walk(c, n.X)
c.convertToken(n.Tok)
c.emitToken(n.Tok, c.typeOf(n.X))
// For now only identifiers are supported for (post) for stmts.
// for i := 0; i < 10; i++ {}
@ -1149,6 +1070,108 @@ func isFallthroughStmt(c ast.Node) bool {
return ok && s.Tok == token.FALLTHROUGH
}
func (c *codegen) getCompareWithNilArg(n *ast.BinaryExpr) ast.Expr {
if isExprNil(n.X) {
return n.Y
} else if isExprNil(n.Y) {
return n.X
}
return nil
}
func (c *codegen) emitJumpOnCondition(cond bool, jmpLabel uint16) {
if cond {
emit.Jmp(c.prog.BinWriter, opcode.JMPIFL, jmpLabel)
} else {
emit.Jmp(c.prog.BinWriter, opcode.JMPIFNOTL, jmpLabel)
}
}
// emitBoolExpr emits boolean expression. If needJump is true and expression evaluates to `cond`,
// jump to jmpLabel is performed and no item is left on stack.
func (c *codegen) emitBoolExpr(n ast.Expr, needJump bool, cond bool, jmpLabel uint16) {
if be, ok := n.(*ast.BinaryExpr); ok {
c.emitBinaryExpr(be, needJump, cond, jmpLabel)
} else {
ast.Walk(c, n)
if needJump {
c.emitJumpOnCondition(cond, jmpLabel)
}
}
}
// emitBinaryExpr emits binary expression. If needJump is true and expression evaluates to `cond`,
// jump to jmpLabel is performed and no item is left on stack.
func (c *codegen) emitBinaryExpr(n *ast.BinaryExpr, needJump bool, cond bool, jmpLabel uint16) {
// The AST package will try to resolve all basic literals for us.
// If the typeinfo.Value is not nil we know that the expr is resolved
// and needs no further action. e.g. x := 2 + 2 + 2 will be resolved to 6.
// NOTE: Constants will also be automatically resolved be the AST parser.
// example:
// const x = 10
// x + 2 will results into 12
tinfo := c.typeAndValueOf(n)
if tinfo.Value != nil {
c.emitLoadConst(tinfo)
if needJump && isBool(tinfo.Type) {
c.emitJumpOnCondition(cond, jmpLabel)
}
return
} else if arg := c.getCompareWithNilArg(n); arg != nil {
ast.Walk(c, arg)
emit.Opcode(c.prog.BinWriter, opcode.ISNULL)
if needJump {
c.emitJumpOnCondition(cond == (n.Op == token.EQL), jmpLabel)
} else if n.Op == token.NEQ {
emit.Opcode(c.prog.BinWriter, opcode.NOT)
}
return
}
switch n.Op {
case token.LAND, token.LOR:
end := c.newLabel()
// true || .. == true, false && .. == false
condShort := n.Op == token.LOR
if needJump {
l := end
if cond == condShort {
l = jmpLabel
}
c.emitBoolExpr(n.X, true, condShort, l)
c.emitBoolExpr(n.Y, true, cond, jmpLabel)
} else {
push := c.newLabel()
c.emitBoolExpr(n.X, true, condShort, push)
c.emitBoolExpr(n.Y, false, false, 0)
emit.Jmp(c.prog.BinWriter, opcode.JMPL, end)
c.setLabel(push)
emit.Bool(c.prog.BinWriter, condShort)
}
c.setLabel(end)
default:
ast.Walk(c, n.X)
ast.Walk(c, n.Y)
typ := c.typeOf(n.X)
if !needJump {
c.emitToken(n.Op, typ)
return
}
op, ok := getJumpForToken(n.Op, typ)
if !ok {
c.emitToken(n.Op, typ)
c.emitJumpOnCondition(cond, jmpLabel)
return
}
if !cond {
op = negateJmp(op)
}
emit.Jmp(c.prog.BinWriter, op, jmpLabel)
}
}
func (c *codegen) pushStackLabel(name string, size int) {
c.labelList = append(c.labelList, labelWithStackSize{
name: name,
@ -1206,13 +1229,27 @@ func (c *codegen) getLabelOffset(typ labelOffsetType, name string) uint16 {
return c.labels[labelWithType{name: name, typ: typ}]
}
func (c *codegen) getEqualityOpcode(expr ast.Expr) opcode.Opcode {
t, ok := c.typeOf(expr).Underlying().(*types.Basic)
if ok && t.Info()&types.IsNumeric != 0 {
return opcode.NUMEQUAL
// For `&&` and `||` it return an opcode which jumps only if result is known:
// false && .. == false, true || .. = true
func getJumpForToken(tok token.Token, typ types.Type) (opcode.Opcode, bool) {
switch tok {
case token.GTR:
return opcode.JMPGTL, true
case token.GEQ:
return opcode.JMPGEL, true
case token.LSS:
return opcode.JMPLTL, true
case token.LEQ:
return opcode.JMPLEL, true
case token.EQL, token.NEQ:
if isNumber(typ) {
if tok == token.EQL {
return opcode.JMPEQL, true
}
return opcode.JMPNEL, true
}
}
return opcode.EQUAL
return 0, false
}
// getByteArray returns byte array value from constant expr.
@ -1460,60 +1497,78 @@ func (c *codegen) convertStruct(lit *ast.CompositeLit, ptr bool) {
}
}
func (c *codegen) convertToken(tok token.Token) {
switch tok {
case token.ADD_ASSIGN:
emit.Opcode(c.prog.BinWriter, opcode.ADD)
case token.SUB_ASSIGN:
emit.Opcode(c.prog.BinWriter, opcode.SUB)
case token.MUL_ASSIGN:
emit.Opcode(c.prog.BinWriter, opcode.MUL)
case token.QUO_ASSIGN:
emit.Opcode(c.prog.BinWriter, opcode.DIV)
case token.REM_ASSIGN:
emit.Opcode(c.prog.BinWriter, opcode.MOD)
case token.ADD:
emit.Opcode(c.prog.BinWriter, opcode.ADD)
case token.SUB:
emit.Opcode(c.prog.BinWriter, opcode.SUB)
case token.MUL:
emit.Opcode(c.prog.BinWriter, opcode.MUL)
case token.QUO:
emit.Opcode(c.prog.BinWriter, opcode.DIV)
case token.REM:
emit.Opcode(c.prog.BinWriter, opcode.MOD)
case token.LSS:
emit.Opcode(c.prog.BinWriter, opcode.LT)
case token.LEQ:
emit.Opcode(c.prog.BinWriter, opcode.LTE)
case token.GTR:
emit.Opcode(c.prog.BinWriter, opcode.GT)
case token.GEQ:
emit.Opcode(c.prog.BinWriter, opcode.GTE)
case token.EQL:
emit.Opcode(c.prog.BinWriter, opcode.NUMEQUAL)
case token.NEQ:
emit.Opcode(c.prog.BinWriter, opcode.NUMNOTEQUAL)
case token.DEC:
emit.Opcode(c.prog.BinWriter, opcode.DEC)
case token.INC:
emit.Opcode(c.prog.BinWriter, opcode.INC)
case token.NOT:
emit.Opcode(c.prog.BinWriter, opcode.NOT)
case token.AND:
emit.Opcode(c.prog.BinWriter, opcode.AND)
case token.OR:
emit.Opcode(c.prog.BinWriter, opcode.OR)
case token.SHL:
emit.Opcode(c.prog.BinWriter, opcode.SHL)
case token.SHR:
emit.Opcode(c.prog.BinWriter, opcode.SHR)
case token.XOR:
emit.Opcode(c.prog.BinWriter, opcode.XOR)
default:
c.prog.Err = fmt.Errorf("compiler could not convert token: %s", tok)
func (c *codegen) emitToken(tok token.Token, typ types.Type) {
op, err := convertToken(tok, typ)
if err != nil {
c.prog.Err = err
return
}
emit.Opcode(c.prog.BinWriter, op)
}
func convertToken(tok token.Token, typ types.Type) (opcode.Opcode, error) {
switch tok {
case token.ADD_ASSIGN, token.ADD:
// VM has separate opcodes for number and string concatenation
if isString(typ) {
return opcode.CAT, nil
}
return opcode.ADD, nil
case token.SUB_ASSIGN:
return opcode.SUB, nil
case token.MUL_ASSIGN:
return opcode.MUL, nil
case token.QUO_ASSIGN:
return opcode.DIV, nil
case token.REM_ASSIGN:
return opcode.MOD, nil
case token.SUB:
return opcode.SUB, nil
case token.MUL:
return opcode.MUL, nil
case token.QUO:
return opcode.DIV, nil
case token.REM:
return opcode.MOD, nil
case token.LSS:
return opcode.LT, nil
case token.LEQ:
return opcode.LTE, nil
case token.GTR:
return opcode.GT, nil
case token.GEQ:
return opcode.GTE, nil
case token.EQL:
// VM has separate opcodes for number and string equality
if isNumber(typ) {
return opcode.NUMEQUAL, nil
}
return opcode.EQUAL, nil
case token.NEQ:
// VM has separate opcodes for number and string equality
if isNumber(typ) {
return opcode.NUMNOTEQUAL, nil
}
return opcode.NOTEQUAL, nil
case token.DEC:
return opcode.DEC, nil
case token.INC:
return opcode.INC, nil
case token.NOT:
return opcode.NOT, nil
case token.AND:
return opcode.AND, nil
case token.OR:
return opcode.OR, nil
case token.SHL:
return opcode.SHL, nil
case token.SHR:
return opcode.SHR, nil
case token.XOR:
return opcode.XOR, nil
default:
return 0, fmt.Errorf("compiler could not convert token: %s", tok)
}
}
func (c *codegen) newFunc(decl *ast.FuncDecl) *funcScope {
@ -1717,13 +1772,13 @@ func shortenJumps(b []byte, offsets []int) []byte {
// 2. Convert instructions.
copyOffset := 0
l := len(offsets)
b[offsets[0]] = toShortForm(b[offsets[0]])
b[offsets[0]] = byte(toShortForm(opcode.Opcode(b[offsets[0]])))
for i := 0; i < l; i++ {
start := offsets[i] + 2
end := len(b)
if i != l-1 {
end = offsets[i+1] + 2
b[offsets[i+1]] = toShortForm(b[offsets[i+1]])
b[offsets[i+1]] = byte(toShortForm(opcode.Opcode(b[offsets[i+1]])))
}
copy(b[start-copyOffset:], b[start+3:end])
copyOffset += longToShortRemoveCount
@ -1749,28 +1804,51 @@ func calcOffsetCorrection(ip, target int, offsets []int) int {
return cnt
}
func toShortForm(b byte) byte {
switch op := opcode.Opcode(b); op {
case opcode.JMPL:
return byte(opcode.JMP)
func negateJmp(op opcode.Opcode) opcode.Opcode {
switch op {
case opcode.JMPIFL:
return byte(opcode.JMPIF)
return opcode.JMPIFNOTL
case opcode.JMPIFNOTL:
return byte(opcode.JMPIFNOT)
return opcode.JMPIFL
case opcode.JMPEQL:
return byte(opcode.JMPEQ)
return opcode.JMPNEL
case opcode.JMPNEL:
return byte(opcode.JMPNE)
return opcode.JMPEQL
case opcode.JMPGTL:
return byte(opcode.JMPGT)
return opcode.JMPLEL
case opcode.JMPGEL:
return byte(opcode.JMPGE)
return opcode.JMPLTL
case opcode.JMPLEL:
return byte(opcode.JMPLE)
return opcode.JMPGTL
case opcode.JMPLTL:
return byte(opcode.JMPLT)
return opcode.JMPGEL
default:
panic(fmt.Errorf("invalid opcode in negateJmp: %s", op))
}
}
func toShortForm(op opcode.Opcode) opcode.Opcode {
switch op {
case opcode.JMPL:
return opcode.JMP
case opcode.JMPIFL:
return opcode.JMPIF
case opcode.JMPIFNOTL:
return opcode.JMPIFNOT
case opcode.JMPEQL:
return opcode.JMPEQ
case opcode.JMPNEL:
return opcode.JMPNE
case opcode.JMPGTL:
return opcode.JMPGT
case opcode.JMPGEL:
return opcode.JMPGE
case opcode.JMPLEL:
return opcode.JMPLE
case opcode.JMPLTL:
return opcode.JMPLT
case opcode.CALLL:
return byte(opcode.CALL)
return opcode.CALL
default:
panic(fmt.Errorf("invalid opcode: %s", op))
}

View file

@ -2,9 +2,9 @@ package compiler
import (
"go/token"
"go/types"
"testing"
"github.com/nspcc-dev/neo-go/pkg/io"
"github.com/nspcc-dev/neo-go/pkg/vm/opcode"
"github.com/stretchr/testify/assert"
)
@ -14,59 +14,79 @@ func TestConvertToken(t *testing.T) {
name string
token token.Token
opcode opcode.Opcode
typ types.Type
}
testCases := []testCase{
{"ADD",
{"ADD (number)",
token.ADD,
opcode.ADD,
types.Typ[types.Int],
},
{"ADD (string)",
token.ADD,
opcode.CAT,
types.Typ[types.String],
},
{"SUB",
token.SUB,
opcode.SUB,
nil,
},
{"MUL",
token.MUL,
opcode.MUL,
nil,
},
{"QUO",
token.QUO,
opcode.DIV,
nil,
},
{"REM",
token.REM,
opcode.MOD,
nil,
},
{"ADD_ASSIGN",
{"ADD_ASSIGN (number)",
token.ADD_ASSIGN,
opcode.ADD,
types.Typ[types.Int],
},
{"ADD_ASSIGN (string)",
token.ADD_ASSIGN,
opcode.CAT,
types.Typ[types.String],
},
{"SUB_ASSIGN",
token.SUB_ASSIGN,
opcode.SUB,
nil,
},
{"MUL_ASSIGN",
token.MUL_ASSIGN,
opcode.MUL,
nil,
},
{"QUO_ASSIGN",
token.QUO_ASSIGN,
opcode.DIV,
nil,
},
{"REM_ASSIGN",
token.REM_ASSIGN,
opcode.MOD,
nil,
},
}
for _, tcase := range testCases {
t.Run(tcase.name, func(t *testing.T) { eval(t, tcase.token, tcase.opcode) })
t.Run(tcase.name, func(t *testing.T) { eval(t, tcase.token, tcase.opcode, tcase.typ) })
}
}
func eval(t *testing.T, token token.Token, opcode opcode.Opcode) {
codegen := &codegen{prog: io.NewBufBinWriter()}
codegen.convertToken(token)
readOpcode := codegen.prog.Bytes()
assert.Equal(t, []byte{byte(opcode)}, readOpcode)
func eval(t *testing.T, token token.Token, opcode opcode.Opcode, typ types.Type) {
op, err := convertToken(token, typ)
assert.NoError(t, err)
assert.Equal(t, opcode, op)
}

View file

@ -31,6 +31,15 @@ func isByte(typ types.Type) bool {
return isBasicTypeOfKind(typ, types.Uint8, types.Int8)
}
func isBool(typ types.Type) bool {
return isBasicTypeOfKind(typ, types.Bool, types.UntypedBool)
}
func isNumber(typ types.Type) bool {
t, ok := typ.Underlying().(*types.Basic)
return ok && t.Info()&types.IsNumeric != 0
}
func isString(typ types.Type) bool {
return isBasicTypeOfKind(typ, types.String)
}