diff --git a/pkg/compiler/analysis.go b/pkg/compiler/analysis.go index 6c052e5af..7b3ff55c9 100644 --- a/pkg/compiler/analysis.go +++ b/pkg/compiler/analysis.go @@ -44,8 +44,9 @@ func typeAndValueForField(fld *types.Var) (types.TypeAndValue, error) { default: return types.TypeAndValue{}, fmt.Errorf("could not initialize struct field %s to zero, type: %s", fld.Name(), t) } + default: + return types.TypeAndValue{Type: t}, nil } - return types.TypeAndValue{}, nil } // countGlobals counts the global variables in the program to add diff --git a/pkg/compiler/codegen.go b/pkg/compiler/codegen.go index c53282399..801937d32 100644 --- a/pkg/compiler/codegen.go +++ b/pkg/compiler/codegen.go @@ -470,7 +470,6 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor { case *ast.SwitchStmt: ast.Walk(c, n.Tag) - eqOpcode := c.getEqualityOpcode(n.Tag) switchEnd, label := c.generateLabel(labelEnd) lastSwitch := c.currentSwitch @@ -490,7 +489,7 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor { for j := range cc.List { emit.Opcode(c.prog.BinWriter, opcode.DUP) ast.Walk(c, cc.List[j]) - emit.Opcode(c.prog.BinWriter, eqOpcode) + c.emitEquality(n.Tag, token.EQL) if j == l-1 { emit.Jmp(c.prog.BinWriter, opcode.JMPIFNOT, lEnd) } else { @@ -533,6 +532,8 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor { c.emitLoadConst(value) } else if tv := c.typeInfo.Types[n]; tv.Value != nil { c.emitLoadConst(tv) + } else if n.Name == "nil" { + c.emitDefault(new(types.Slice)) } else { c.emitLoadLocal(n.Name) } @@ -615,26 +616,20 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor { ast.Walk(c, n.X) ast.Walk(c, n.Y) - switch { - case n.Op == token.ADD: + switch n.Op { + case token.ADD: // VM has separate opcodes for number and string concatenation if isStringType(tinfo.Type) { emit.Opcode(c.prog.BinWriter, opcode.CAT) } else { emit.Opcode(c.prog.BinWriter, opcode.ADD) } - case n.Op == token.EQL: - // VM has separate opcodes for number and string equality - op := c.getEqualityOpcode(n.X) - emit.Opcode(c.prog.BinWriter, op) - case n.Op == token.NEQ: - // VM has separate opcodes for number and string equality - if isStringType(c.typeInfo.Types[n.X].Type) { - emit.Opcode(c.prog.BinWriter, opcode.EQUAL) - emit.Opcode(c.prog.BinWriter, opcode.NOT) - } else { - emit.Opcode(c.prog.BinWriter, opcode.NUMNOTEQUAL) + case token.EQL, token.NEQ: + if isExprNil(n.X) || isExprNil(n.Y) { + c.prog.Err = errors.New("comparison with `nil` is not supported, use `len(..) == 0` instead") + return nil } + c.emitEquality(n.X, n.Op) default: c.convertToken(n.Op) } @@ -980,13 +975,26 @@ func (c *codegen) getLabelOffset(typ labelOffsetType, name string) uint16 { return c.labels[labelWithType{name: name, typ: typ}] } -func (c *codegen) getEqualityOpcode(expr ast.Expr) opcode.Opcode { +func (c *codegen) emitEquality(expr ast.Expr, op token.Token) { t, ok := c.typeInfo.Types[expr].Type.Underlying().(*types.Basic) - if ok && t.Info()&types.IsNumeric != 0 { - return opcode.NUMEQUAL + isNum := ok && t.Info()&types.IsNumeric != 0 + switch op { + case token.EQL: + if isNum { + emit.Opcode(c.prog.BinWriter, opcode.NUMEQUAL) + } else { + emit.Opcode(c.prog.BinWriter, opcode.EQUAL) + } + case token.NEQ: + if isNum { + emit.Opcode(c.prog.BinWriter, opcode.NUMNOTEQUAL) + } else { + emit.Opcode(c.prog.BinWriter, opcode.EQUAL) + emit.Opcode(c.prog.BinWriter, opcode.NOT) + } + default: + panic("invalid token in emitEqual()") } - - return opcode.EQUAL } // getByteArray returns byte array value from constant expr. @@ -1230,11 +1238,32 @@ func (c *codegen) convertStruct(lit *ast.CompositeLit) { emit.Opcode(c.prog.BinWriter, opcode.DUP) emit.Int(c.prog.BinWriter, int64(i)) - c.emitLoadConst(typeAndVal) + c.emitDefault(typeAndVal.Type) emit.Opcode(c.prog.BinWriter, opcode.SETITEM) } } +func (c *codegen) emitDefault(typ types.Type) { + switch t := c.scTypeFromGo(typ); t { + case "Integer": + emit.Int(c.prog.BinWriter, 0) + case "Boolean": + emit.Bool(c.prog.BinWriter, false) + case "String": + emit.String(c.prog.BinWriter, "") + case "Map": + emit.Opcode(c.prog.BinWriter, opcode.NEWMAP) + case "Struct": + emit.Int(c.prog.BinWriter, int64(typ.(*types.Struct).NumFields())) + emit.Opcode(c.prog.BinWriter, opcode.NEWSTRUCT) + case "Array": + emit.Int(c.prog.BinWriter, 0) + emit.Opcode(c.prog.BinWriter, opcode.NEWARRAY) + case "ByteArray": + emit.Bytes(c.prog.BinWriter, []byte{}) + } +} + func (c *codegen) convertToken(tok token.Token) { switch tok { case token.ADD_ASSIGN: diff --git a/pkg/compiler/debug.go b/pkg/compiler/debug.go index 0d607eaae..91d0bbfbb 100644 --- a/pkg/compiler/debug.go +++ b/pkg/compiler/debug.go @@ -183,7 +183,11 @@ func (c *codegen) scReturnTypeFromScope(scope *funcScope) string { } func (c *codegen) scTypeFromExpr(typ ast.Expr) string { - switch t := c.typeInfo.Types[typ].Type.(type) { + return c.scTypeFromGo(c.typeInfo.Types[typ].Type) +} + +func (c *codegen) scTypeFromGo(typ types.Type) string { + switch t := typ.(type) { case *types.Basic: info := t.Info() switch { diff --git a/pkg/compiler/slice_test.go b/pkg/compiler/slice_test.go index ca4c25798..d1f77a369 100644 --- a/pkg/compiler/slice_test.go +++ b/pkg/compiler/slice_test.go @@ -1,10 +1,14 @@ package compiler_test import ( + "fmt" "math/big" + "strings" "testing" + "github.com/nspcc-dev/neo-go/pkg/compiler" "github.com/nspcc-dev/neo-go/pkg/vm" + "github.com/stretchr/testify/require" ) var sliceTestCases = []testCase{ @@ -175,6 +179,31 @@ func TestSliceOperations(t *testing.T) { runTestCases(t, sliceTestCases) } +func TestSliceEmpty(t *testing.T) { + srcTmpl := `package foo + func Main() int { + var a []int + %s + if %s { + return 1 + } + return 2 + }` + t.Run("WithNil", func(t *testing.T) { + src := fmt.Sprintf(srcTmpl, "", "a == nil") + _, err := compiler.Compile(strings.NewReader(src)) + require.Error(t, err) + }) + t.Run("WithLen", func(t *testing.T) { + src := fmt.Sprintf(srcTmpl, "", "len(a) == 0") + eval(t, src, big.NewInt(1)) + }) + t.Run("NonEmpty", func(t *testing.T) { + src := fmt.Sprintf(srcTmpl, "a = []int{1}", "len(a) == 0") + eval(t, src, big.NewInt(2)) + }) +} + func TestJumps(t *testing.T) { src := ` package foo diff --git a/pkg/compiler/struct_test.go b/pkg/compiler/struct_test.go index d945ab775..58f685087 100644 --- a/pkg/compiler/struct_test.go +++ b/pkg/compiler/struct_test.go @@ -1,6 +1,7 @@ package compiler_test import ( + "fmt" "math/big" "testing" @@ -338,8 +339,53 @@ var structTestCases = []testCase{ }`, big.NewInt(2), }, + { + "uninitialized struct fields", + `package foo + type Foo struct { + i int + m map[string]int + b []byte + a []int + s struct { ii int } + } + func NewFoo() Foo { return Foo{} } + func Main() int { + foo := NewFoo() + if foo.i != 0 { return 1 } + if len(foo.m) != 0 { return 1 } + if len(foo.b) != 0 { return 1 } + if len(foo.a) != 0 { return 1 } + s := foo.s + if s.ii != 0 { return 1 } + return 2 + }`, + big.NewInt(2), + }, } func TestStructs(t *testing.T) { runTestCases(t, structTestCases) } + +func TestStructCompare(t *testing.T) { + srcTmpl := `package testcase + type T struct { f int } + func Main() int { + a := T{f: %d} + b := T{f: %d} + if a != b { + return 2 + } + return 1 + }` + t.Run("Equal", func(t *testing.T) { + src := fmt.Sprintf(srcTmpl, 4, 4) + eval(t, src, big.NewInt(1)) + }) + t.Run("NotEqual", func(t *testing.T) { + src := fmt.Sprintf(srcTmpl, 4, 5) + eval(t, src, big.NewInt(2)) + }) + +}