diff --git a/pkg/compiler/codegen.go b/pkg/compiler/codegen.go index 96f9f87e5..b30513dda 100644 --- a/pkg/compiler/codegen.go +++ b/pkg/compiler/codegen.go @@ -659,8 +659,26 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor { return nil } - ast.Walk(c, n.X) - ast.Walk(c, n.Y) + var checkForNull bool + + if isExprNil(n.X) { + checkForNull = true + } else { + ast.Walk(c, n.X) + } + if isExprNil(n.Y) { + checkForNull = true + } else { + ast.Walk(c, n.Y) + } + if checkForNull { + emit.Opcode(c.prog.BinWriter, opcode.ISNULL) + if n.Op == token.NEQ { + emit.Opcode(c.prog.BinWriter, opcode.NOT) + } + + return nil + } switch { case n.Op == token.ADD: diff --git a/pkg/compiler/nilcheck_test.go b/pkg/compiler/nilcheck_test.go new file mode 100644 index 000000000..54f36e037 --- /dev/null +++ b/pkg/compiler/nilcheck_test.go @@ -0,0 +1,55 @@ +package compiler_test + +import ( + "math/big" + "testing" +) + +var nilTestCases = []testCase{ + { + "nil check positive right", + ` + package foo + func Main() int { + var t interface{} + if t == nil { + return 1 + } + return 2 + } + `, + big.NewInt(1), + }, + { + "nil check negative right", + ` + package foo + func Main() int { + t := []byte{} + if t == nil { + return 1 + } + return 2 + } + `, + big.NewInt(2), + }, + { + "nil check positive left", + ` + package foo + func Main() int { + var t interface{} + if nil == t { + return 1 + } + return 2 + } + `, + big.NewInt(1), + }, +} + +func TestNil(t *testing.T) { + runTestCases(t, nilTestCases) +}