diff --git a/pkg/compiler/codegen.go b/pkg/compiler/codegen.go index 96c3aa68b..c1014cfd7 100644 --- a/pkg/compiler/codegen.go +++ b/pkg/compiler/codegen.go @@ -1310,41 +1310,32 @@ func (c *codegen) convertStruct(lit *ast.CompositeLit) { // We will initialize all fields to their "zero" value. for i := 0; i < strct.NumFields(); i++ { sField := strct.Field(i) - fieldAdded := false - - if !keyedLit { - emit.Opcode(c.prog.BinWriter, opcode.DUP) - emit.Int(c.prog.BinWriter, int64(i)) - ast.Walk(c, lit.Elts[i]) - emit.Opcode(c.prog.BinWriter, opcode.SETITEM) - continue - } - - // Fields initialized by the program. - for _, field := range lit.Elts { - f := field.(*ast.KeyValueExpr) - fieldName := f.Key.(*ast.Ident).Name - - if sField.Name() == fieldName { - emit.Opcode(c.prog.BinWriter, opcode.DUP) - - pos := indexOfStruct(strct, fieldName) - emit.Int(c.prog.BinWriter, int64(pos)) - - ast.Walk(c, f.Value) - - emit.Opcode(c.prog.BinWriter, opcode.SETITEM) - fieldAdded = true - break - } - } - if fieldAdded { - continue - } + var initialized bool emit.Opcode(c.prog.BinWriter, opcode.DUP) emit.Int(c.prog.BinWriter, int64(i)) - c.emitDefault(sField.Type()) + + if !keyedLit { + if len(lit.Elts) > i { + ast.Walk(c, lit.Elts[i]) + initialized = true + } + } else { + // Fields initialized by the program. + for _, field := range lit.Elts { + f := field.(*ast.KeyValueExpr) + fieldName := f.Key.(*ast.Ident).Name + + if sField.Name() == fieldName { + ast.Walk(c, f.Value) + initialized = true + break + } + } + } + if !initialized { + c.emitDefault(sField.Type()) + } emit.Opcode(c.prog.BinWriter, opcode.SETITEM) } } diff --git a/pkg/compiler/struct_test.go b/pkg/compiler/struct_test.go index 945273d28..2048d612c 100644 --- a/pkg/compiler/struct_test.go +++ b/pkg/compiler/struct_test.go @@ -405,6 +405,29 @@ var structTestCases = []testCase{ }`, big.NewInt(12), }, + { + "uninitialized struct fields", + `package foo + type Foo struct { + i int + m map[string]int + b []byte + a []int + s struct { ii int } + } + func NewFoo() Foo { return Foo{} } + func Main() int { + foo := NewFoo() + if foo.i != 0 { return 1 } + if len(foo.m) != 0 { return 1 } + if len(foo.b) != 0 { return 1 } + if len(foo.a) != 0 { return 1 } + s := foo.s + if s.ii != 0 { return 1 } + return 2 + }`, + big.NewInt(2), + }, } func TestStructs(t *testing.T) {