Merge pull request #1166 from nspcc-dev/fix-struct-default-init

Fix struct default init
This commit is contained in:
Roman Khimov 2020-07-09 14:52:00 +03:00 committed by GitHub
commit 2a16df8db1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 132 additions and 44 deletions

View file

@ -231,13 +231,6 @@ func (c *codegen) emitDefault(t types.Type) {
default: default:
emit.Opcode(c.prog.BinWriter, opcode.PUSHNULL) emit.Opcode(c.prog.BinWriter, opcode.PUSHNULL)
} }
case *types.Slice:
if isCompoundSlice(t) {
emit.Opcode(c.prog.BinWriter, opcode.NEWARRAY0)
} else {
emit.Int(c.prog.BinWriter, 0)
emit.Opcode(c.prog.BinWriter, opcode.NEWBUFFER)
}
case *types.Struct: case *types.Struct:
num := t.NumFields() num := t.NumFields()
emit.Int(c.prog.BinWriter, int64(num)) emit.Int(c.prog.BinWriter, int64(num))
@ -1183,12 +1176,29 @@ func (c *codegen) convertBuiltin(expr *ast.CallExpr) {
case "append": case "append":
arg := expr.Args[0] arg := expr.Args[0]
typ := c.typeInfo.Types[arg].Type typ := c.typeInfo.Types[arg].Type
c.emitReverse(len(expr.Args))
emit.Opcode(c.prog.BinWriter, opcode.DUP)
emit.Opcode(c.prog.BinWriter, opcode.ISNULL)
emit.Instruction(c.prog.BinWriter, opcode.JMPIFNOT, []byte{2 + 3})
if isByteSlice(typ) { if isByteSlice(typ) {
emit.Opcode(c.prog.BinWriter, opcode.CAT) emit.Opcode(c.prog.BinWriter, opcode.DROP)
emit.Opcode(c.prog.BinWriter, opcode.PUSH0)
emit.Opcode(c.prog.BinWriter, opcode.NEWBUFFER)
} else { } else {
emit.Opcode(c.prog.BinWriter, opcode.OVER) emit.Opcode(c.prog.BinWriter, opcode.DROP)
emit.Opcode(c.prog.BinWriter, opcode.SWAP) emit.Opcode(c.prog.BinWriter, opcode.NEWARRAY0)
emit.Opcode(c.prog.BinWriter, opcode.APPEND) emit.Opcode(c.prog.BinWriter, opcode.NOP)
}
// Jump target.
for range expr.Args[1:] {
if isByteSlice(typ) {
emit.Opcode(c.prog.BinWriter, opcode.SWAP)
emit.Opcode(c.prog.BinWriter, opcode.CAT)
} else {
emit.Opcode(c.prog.BinWriter, opcode.DUP)
emit.Opcode(c.prog.BinWriter, opcode.ROT)
emit.Opcode(c.prog.BinWriter, opcode.APPEND)
}
} }
case "panic": case "panic":
arg := expr.Args[0] arg := expr.Args[0]
@ -1310,41 +1320,32 @@ func (c *codegen) convertStruct(lit *ast.CompositeLit) {
// We will initialize all fields to their "zero" value. // We will initialize all fields to their "zero" value.
for i := 0; i < strct.NumFields(); i++ { for i := 0; i < strct.NumFields(); i++ {
sField := strct.Field(i) sField := strct.Field(i)
fieldAdded := false var initialized bool
if !keyedLit {
emit.Opcode(c.prog.BinWriter, opcode.DUP)
emit.Int(c.prog.BinWriter, int64(i))
ast.Walk(c, lit.Elts[i])
emit.Opcode(c.prog.BinWriter, opcode.SETITEM)
continue
}
// Fields initialized by the program.
for _, field := range lit.Elts {
f := field.(*ast.KeyValueExpr)
fieldName := f.Key.(*ast.Ident).Name
if sField.Name() == fieldName {
emit.Opcode(c.prog.BinWriter, opcode.DUP)
pos := indexOfStruct(strct, fieldName)
emit.Int(c.prog.BinWriter, int64(pos))
ast.Walk(c, f.Value)
emit.Opcode(c.prog.BinWriter, opcode.SETITEM)
fieldAdded = true
break
}
}
if fieldAdded {
continue
}
emit.Opcode(c.prog.BinWriter, opcode.DUP) emit.Opcode(c.prog.BinWriter, opcode.DUP)
emit.Int(c.prog.BinWriter, int64(i)) emit.Int(c.prog.BinWriter, int64(i))
c.emitDefault(sField.Type())
if !keyedLit {
if len(lit.Elts) > i {
ast.Walk(c, lit.Elts[i])
initialized = true
}
} else {
// Fields initialized by the program.
for _, field := range lit.Elts {
f := field.(*ast.KeyValueExpr)
fieldName := f.Key.(*ast.Ident).Name
if sField.Name() == fieldName {
ast.Walk(c, f.Value)
initialized = true
break
}
}
}
if !initialized {
c.emitDefault(sField.Type())
}
emit.Opcode(c.prog.BinWriter, opcode.SETITEM) emit.Opcode(c.prog.BinWriter, opcode.SETITEM)
} }
} }
@ -1522,7 +1523,7 @@ func (c *codegen) writeJumps(b []byte) error {
opcode.JMPEQ, opcode.JMPNE, opcode.JMPEQ, opcode.JMPNE,
opcode.JMPGT, opcode.JMPGE, opcode.JMPLE, opcode.JMPLT: opcode.JMPGT, opcode.JMPGE, opcode.JMPLE, opcode.JMPLT:
// Noop, assumed to be correct already. If you're fixing #905, // Noop, assumed to be correct already. If you're fixing #905,
// make sure not to break "len" handling above. // make sure not to break "len" and "append" handling above.
case opcode.JMPL, opcode.JMPIFL, opcode.JMPIFNOTL, case opcode.JMPL, opcode.JMPIFL, opcode.JMPIFNOTL,
opcode.JMPEQL, opcode.JMPNEL, opcode.JMPEQL, opcode.JMPNEL,
opcode.JMPGTL, opcode.JMPGEL, opcode.JMPLEL, opcode.JMPLTL, opcode.JMPGTL, opcode.JMPGEL, opcode.JMPLEL, opcode.JMPLTL,

View file

@ -151,6 +151,33 @@ var sliceTestCases = []testCase{
}`, }`,
[]byte{1, 2}, []byte{1, 2},
}, },
{
"append multiple bytes to a slice",
`package foo
func Main() []byte {
var a []byte
a = append(a, 1, 2)
return a
}`,
[]byte{1, 2},
},
{
"append multiple ints to a slice",
`package foo
func Main() []int {
var a []int
a = append(a, 1, 2, 3)
a = append(a, 4, 5)
return a
}`,
[]stackitem.Item{
stackitem.NewBigInteger(big.NewInt(1)),
stackitem.NewBigInteger(big.NewInt(2)),
stackitem.NewBigInteger(big.NewInt(3)),
stackitem.NewBigInteger(big.NewInt(4)),
stackitem.NewBigInteger(big.NewInt(5)),
},
},
{ {
"declare compound slice", "declare compound slice",
`package foo `package foo
@ -243,6 +270,43 @@ var sliceTestCases = []testCase{
}`, }`,
big.NewInt(42), big.NewInt(42),
}, },
{
"defaults to nil for byte slice",
`
package foo
func Main() int {
var a []byte
if a != nil { return 1}
return 2
}
`,
big.NewInt(2),
},
{
"defaults to nil for int slice",
`
package foo
func Main() int {
var a []int
if a != nil { return 1}
return 2
}
`,
big.NewInt(2),
},
{
"defaults to nil for struct slice",
`
package foo
type pair struct { a, b int }
func Main() int {
var a []pair
if a != nil { return 1}
return 2
}
`,
big.NewInt(2),
},
} }
func TestSliceOperations(t *testing.T) { func TestSliceOperations(t *testing.T) {

View file

@ -405,6 +405,29 @@ var structTestCases = []testCase{
}`, }`,
big.NewInt(12), big.NewInt(12),
}, },
{
"uninitialized struct fields",
`package foo
type Foo struct {
i int
m map[string]int
b []byte
a []int
s struct { ii int }
}
func NewFoo() Foo { return Foo{} }
func Main() int {
foo := NewFoo()
if foo.i != 0 { return 1 }
if len(foo.m) != 0 { return 1 }
if len(foo.b) != 0 { return 1 }
if len(foo.a) != 0 { return 1 }
s := foo.s
if s.ii != 0 { return 1 }
return 2
}`,
big.NewInt(2),
},
} }
func TestStructs(t *testing.T) { func TestStructs(t *testing.T) {