diff --git a/pkg/compiler/codegen.go b/pkg/compiler/codegen.go index 3b7771ddf..fcae45fbf 100644 --- a/pkg/compiler/codegen.go +++ b/pkg/compiler/codegen.go @@ -561,8 +561,7 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor { lElseEnd := c.newLabel() if n.Cond != nil { - ast.Walk(c, n.Cond) - emit.Jmp(c.prog.BinWriter, opcode.JMPIFNOTL, lElse) + c.emitBoolExpr(n.Cond, true, false, lElse) } c.setLabel(lIf) @@ -689,50 +688,8 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor { return nil case *ast.BinaryExpr: - // The AST package will try to resolve all basic literals for us. - // If the typeinfo.Value is not nil we know that the expr is resolved - // and needs no further action. e.g. x := 2 + 2 + 2 will be resolved to 6. - // NOTE: Constants will also be automatically resolved be the AST parser. - // example: - // const x = 10 - // x + 2 will results into 12 - tinfo := c.typeAndValueOf(n) - if tinfo.Value != nil { - c.emitLoadConst(tinfo) - return nil - } - - if arg := c.getCompareWithNilArg(n); arg != nil { - ast.Walk(c, arg) - emit.Opcode(c.prog.BinWriter, opcode.ISNULL) - if n.Op == token.NEQ { - emit.Opcode(c.prog.BinWriter, opcode.NOT) - } - return nil - } - - switch n.Op { - case token.LAND, token.LOR: - end := c.newLabel() - ast.Walk(c, n.X) - if n.Op == token.LAND { - emit.Instruction(c.prog.BinWriter, opcode.JMPIF, []byte{2 + 1 + 5}) - emit.Opcode(c.prog.BinWriter, opcode.PUSHF) - } else { - emit.Instruction(c.prog.BinWriter, opcode.JMPIFNOT, []byte{2 + 1 + 5}) - emit.Opcode(c.prog.BinWriter, opcode.PUSHT) - } - emit.Jmp(c.prog.BinWriter, opcode.JMPL, end) - ast.Walk(c, n.Y) - c.setLabel(end) - return nil - - default: - ast.Walk(c, n.X) - ast.Walk(c, n.Y) - c.emitToken(n.Op, c.typeOf(n.X)) - return nil - } + c.emitBinaryExpr(n, false, false, 0) + return nil case *ast.CallExpr: var ( @@ -1122,6 +1079,84 @@ func (c *codegen) getCompareWithNilArg(n *ast.BinaryExpr) ast.Expr { return nil } +func (c *codegen) emitJumpOnCondition(cond bool, jmpLabel uint16) { + if cond { + emit.Jmp(c.prog.BinWriter, opcode.JMPIFL, jmpLabel) + } else { + emit.Jmp(c.prog.BinWriter, opcode.JMPIFNOTL, jmpLabel) + } +} + +func (c *codegen) emitBoolExpr(n ast.Expr, needJump bool, cond bool, jmpLabel uint16) { + if be, ok := n.(*ast.BinaryExpr); ok { + c.emitBinaryExpr(be, needJump, cond, jmpLabel) + } else { + ast.Walk(c, n) + if needJump { + c.emitJumpOnCondition(cond, jmpLabel) + } + } +} + +func (c *codegen) emitBinaryExpr(n *ast.BinaryExpr, needJump bool, cond bool, jmpLabel uint16) { + // The AST package will try to resolve all basic literals for us. + // If the typeinfo.Value is not nil we know that the expr is resolved + // and needs no further action. e.g. x := 2 + 2 + 2 will be resolved to 6. + // NOTE: Constants will also be automatically resolved be the AST parser. + // example: + // const x = 10 + // x + 2 will results into 12 + tinfo := c.typeAndValueOf(n) + if tinfo.Value != nil { + c.emitLoadConst(tinfo) + if needJump && isBool(tinfo.Type) { + c.emitJumpOnCondition(cond, jmpLabel) + } + return + } else if arg := c.getCompareWithNilArg(n); arg != nil { + ast.Walk(c, arg) + emit.Opcode(c.prog.BinWriter, opcode.ISNULL) + if needJump { + c.emitJumpOnCondition(cond == (n.Op == token.EQL), jmpLabel) + } else if n.Op == token.NEQ { + emit.Opcode(c.prog.BinWriter, opcode.NOT) + } + return + } + + switch n.Op { + case token.LAND, token.LOR: + end := c.newLabel() + + // true || .. == true, false && .. == false + condShort := n.Op == token.LOR + if needJump { + l := end + if cond == condShort { + l = jmpLabel + } + c.emitBoolExpr(n.X, true, condShort, l) + c.emitBoolExpr(n.Y, true, cond, jmpLabel) + } else { + push := c.newLabel() + c.emitBoolExpr(n.X, true, condShort, push) + c.emitBoolExpr(n.Y, false, false, 0) + emit.Jmp(c.prog.BinWriter, opcode.JMPL, end) + c.setLabel(push) + emit.Bool(c.prog.BinWriter, condShort) + } + c.setLabel(end) + + default: + ast.Walk(c, n.X) + ast.Walk(c, n.Y) + c.emitToken(n.Op, c.typeOf(n.X)) + if needJump { + c.emitJumpOnCondition(cond, jmpLabel) + } + } +} + func (c *codegen) pushStackLabel(name string, size int) { c.labelList = append(c.labelList, labelWithStackSize{ name: name, diff --git a/pkg/compiler/types.go b/pkg/compiler/types.go index e03dcf286..9b17de395 100644 --- a/pkg/compiler/types.go +++ b/pkg/compiler/types.go @@ -31,6 +31,10 @@ func isByte(typ types.Type) bool { return isBasicTypeOfKind(typ, types.Uint8, types.Int8) } +func isBool(typ types.Type) bool { + return isBasicTypeOfKind(typ, types.Bool, types.UntypedBool) +} + func isNumber(typ types.Type) bool { t, ok := typ.Underlying().(*types.Basic) return ok && t.Info()&types.IsNumeric != 0