Merge pull request #952 from nspcc-dev/fix/structfield

compiler: set default values for complex struct fields (2.x)
This commit is contained in:
Roman Khimov 2020-06-11 17:50:18 +03:00 committed by GitHub
commit 9615e83c00
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 132 additions and 23 deletions

View file

@ -44,8 +44,9 @@ func typeAndValueForField(fld *types.Var) (types.TypeAndValue, error) {
default: default:
return types.TypeAndValue{}, fmt.Errorf("could not initialize struct field %s to zero, type: %s", fld.Name(), t) return types.TypeAndValue{}, fmt.Errorf("could not initialize struct field %s to zero, type: %s", fld.Name(), t)
} }
default:
return types.TypeAndValue{Type: t}, nil
} }
return types.TypeAndValue{}, nil
} }
// countGlobals counts the global variables in the program to add // countGlobals counts the global variables in the program to add

View file

@ -470,7 +470,6 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
case *ast.SwitchStmt: case *ast.SwitchStmt:
ast.Walk(c, n.Tag) ast.Walk(c, n.Tag)
eqOpcode := c.getEqualityOpcode(n.Tag)
switchEnd, label := c.generateLabel(labelEnd) switchEnd, label := c.generateLabel(labelEnd)
lastSwitch := c.currentSwitch lastSwitch := c.currentSwitch
@ -490,7 +489,7 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
for j := range cc.List { for j := range cc.List {
emit.Opcode(c.prog.BinWriter, opcode.DUP) emit.Opcode(c.prog.BinWriter, opcode.DUP)
ast.Walk(c, cc.List[j]) ast.Walk(c, cc.List[j])
emit.Opcode(c.prog.BinWriter, eqOpcode) c.emitEquality(n.Tag, token.EQL)
if j == l-1 { if j == l-1 {
emit.Jmp(c.prog.BinWriter, opcode.JMPIFNOT, lEnd) emit.Jmp(c.prog.BinWriter, opcode.JMPIFNOT, lEnd)
} else { } else {
@ -533,6 +532,8 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
c.emitLoadConst(value) c.emitLoadConst(value)
} else if tv := c.typeInfo.Types[n]; tv.Value != nil { } else if tv := c.typeInfo.Types[n]; tv.Value != nil {
c.emitLoadConst(tv) c.emitLoadConst(tv)
} else if n.Name == "nil" {
c.emitDefault(new(types.Slice))
} else { } else {
c.emitLoadLocal(n.Name) c.emitLoadLocal(n.Name)
} }
@ -615,26 +616,20 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
ast.Walk(c, n.X) ast.Walk(c, n.X)
ast.Walk(c, n.Y) ast.Walk(c, n.Y)
switch { switch n.Op {
case n.Op == token.ADD: case token.ADD:
// VM has separate opcodes for number and string concatenation // VM has separate opcodes for number and string concatenation
if isStringType(tinfo.Type) { if isStringType(tinfo.Type) {
emit.Opcode(c.prog.BinWriter, opcode.CAT) emit.Opcode(c.prog.BinWriter, opcode.CAT)
} else { } else {
emit.Opcode(c.prog.BinWriter, opcode.ADD) emit.Opcode(c.prog.BinWriter, opcode.ADD)
} }
case n.Op == token.EQL: case token.EQL, token.NEQ:
// VM has separate opcodes for number and string equality if isExprNil(n.X) || isExprNil(n.Y) {
op := c.getEqualityOpcode(n.X) c.prog.Err = errors.New("comparison with `nil` is not supported, use `len(..) == 0` instead")
emit.Opcode(c.prog.BinWriter, op) return nil
case n.Op == token.NEQ:
// VM has separate opcodes for number and string equality
if isStringType(c.typeInfo.Types[n.X].Type) {
emit.Opcode(c.prog.BinWriter, opcode.EQUAL)
emit.Opcode(c.prog.BinWriter, opcode.NOT)
} else {
emit.Opcode(c.prog.BinWriter, opcode.NUMNOTEQUAL)
} }
c.emitEquality(n.X, n.Op)
default: default:
c.convertToken(n.Op) c.convertToken(n.Op)
} }
@ -980,13 +975,26 @@ func (c *codegen) getLabelOffset(typ labelOffsetType, name string) uint16 {
return c.labels[labelWithType{name: name, typ: typ}] return c.labels[labelWithType{name: name, typ: typ}]
} }
func (c *codegen) getEqualityOpcode(expr ast.Expr) opcode.Opcode { func (c *codegen) emitEquality(expr ast.Expr, op token.Token) {
t, ok := c.typeInfo.Types[expr].Type.Underlying().(*types.Basic) t, ok := c.typeInfo.Types[expr].Type.Underlying().(*types.Basic)
if ok && t.Info()&types.IsNumeric != 0 { isNum := ok && t.Info()&types.IsNumeric != 0
return opcode.NUMEQUAL switch op {
case token.EQL:
if isNum {
emit.Opcode(c.prog.BinWriter, opcode.NUMEQUAL)
} else {
emit.Opcode(c.prog.BinWriter, opcode.EQUAL)
}
case token.NEQ:
if isNum {
emit.Opcode(c.prog.BinWriter, opcode.NUMNOTEQUAL)
} else {
emit.Opcode(c.prog.BinWriter, opcode.EQUAL)
emit.Opcode(c.prog.BinWriter, opcode.NOT)
}
default:
panic("invalid token in emitEqual()")
} }
return opcode.EQUAL
} }
// getByteArray returns byte array value from constant expr. // getByteArray returns byte array value from constant expr.
@ -1230,11 +1238,32 @@ func (c *codegen) convertStruct(lit *ast.CompositeLit) {
emit.Opcode(c.prog.BinWriter, opcode.DUP) emit.Opcode(c.prog.BinWriter, opcode.DUP)
emit.Int(c.prog.BinWriter, int64(i)) emit.Int(c.prog.BinWriter, int64(i))
c.emitLoadConst(typeAndVal) c.emitDefault(typeAndVal.Type)
emit.Opcode(c.prog.BinWriter, opcode.SETITEM) emit.Opcode(c.prog.BinWriter, opcode.SETITEM)
} }
} }
func (c *codegen) emitDefault(typ types.Type) {
switch t := c.scTypeFromGo(typ); t {
case "Integer":
emit.Int(c.prog.BinWriter, 0)
case "Boolean":
emit.Bool(c.prog.BinWriter, false)
case "String":
emit.String(c.prog.BinWriter, "")
case "Map":
emit.Opcode(c.prog.BinWriter, opcode.NEWMAP)
case "Struct":
emit.Int(c.prog.BinWriter, int64(typ.(*types.Struct).NumFields()))
emit.Opcode(c.prog.BinWriter, opcode.NEWSTRUCT)
case "Array":
emit.Int(c.prog.BinWriter, 0)
emit.Opcode(c.prog.BinWriter, opcode.NEWARRAY)
case "ByteArray":
emit.Bytes(c.prog.BinWriter, []byte{})
}
}
func (c *codegen) convertToken(tok token.Token) { func (c *codegen) convertToken(tok token.Token) {
switch tok { switch tok {
case token.ADD_ASSIGN: case token.ADD_ASSIGN:

View file

@ -183,7 +183,11 @@ func (c *codegen) scReturnTypeFromScope(scope *funcScope) string {
} }
func (c *codegen) scTypeFromExpr(typ ast.Expr) string { func (c *codegen) scTypeFromExpr(typ ast.Expr) string {
switch t := c.typeInfo.Types[typ].Type.(type) { return c.scTypeFromGo(c.typeInfo.Types[typ].Type)
}
func (c *codegen) scTypeFromGo(typ types.Type) string {
switch t := typ.(type) {
case *types.Basic: case *types.Basic:
info := t.Info() info := t.Info()
switch { switch {

View file

@ -1,10 +1,14 @@
package compiler_test package compiler_test
import ( import (
"fmt"
"math/big" "math/big"
"strings"
"testing" "testing"
"github.com/nspcc-dev/neo-go/pkg/compiler"
"github.com/nspcc-dev/neo-go/pkg/vm" "github.com/nspcc-dev/neo-go/pkg/vm"
"github.com/stretchr/testify/require"
) )
var sliceTestCases = []testCase{ var sliceTestCases = []testCase{
@ -175,6 +179,31 @@ func TestSliceOperations(t *testing.T) {
runTestCases(t, sliceTestCases) runTestCases(t, sliceTestCases)
} }
func TestSliceEmpty(t *testing.T) {
srcTmpl := `package foo
func Main() int {
var a []int
%s
if %s {
return 1
}
return 2
}`
t.Run("WithNil", func(t *testing.T) {
src := fmt.Sprintf(srcTmpl, "", "a == nil")
_, err := compiler.Compile(strings.NewReader(src))
require.Error(t, err)
})
t.Run("WithLen", func(t *testing.T) {
src := fmt.Sprintf(srcTmpl, "", "len(a) == 0")
eval(t, src, big.NewInt(1))
})
t.Run("NonEmpty", func(t *testing.T) {
src := fmt.Sprintf(srcTmpl, "a = []int{1}", "len(a) == 0")
eval(t, src, big.NewInt(2))
})
}
func TestJumps(t *testing.T) { func TestJumps(t *testing.T) {
src := ` src := `
package foo package foo

View file

@ -1,6 +1,7 @@
package compiler_test package compiler_test
import ( import (
"fmt"
"math/big" "math/big"
"testing" "testing"
@ -338,8 +339,53 @@ var structTestCases = []testCase{
}`, }`,
big.NewInt(2), big.NewInt(2),
}, },
{
"uninitialized struct fields",
`package foo
type Foo struct {
i int
m map[string]int
b []byte
a []int
s struct { ii int }
}
func NewFoo() Foo { return Foo{} }
func Main() int {
foo := NewFoo()
if foo.i != 0 { return 1 }
if len(foo.m) != 0 { return 1 }
if len(foo.b) != 0 { return 1 }
if len(foo.a) != 0 { return 1 }
s := foo.s
if s.ii != 0 { return 1 }
return 2
}`,
big.NewInt(2),
},
} }
func TestStructs(t *testing.T) { func TestStructs(t *testing.T) {
runTestCases(t, structTestCases) runTestCases(t, structTestCases)
} }
func TestStructCompare(t *testing.T) {
srcTmpl := `package testcase
type T struct { f int }
func Main() int {
a := T{f: %d}
b := T{f: %d}
if a != b {
return 2
}
return 1
}`
t.Run("Equal", func(t *testing.T) {
src := fmt.Sprintf(srcTmpl, 4, 4)
eval(t, src, big.NewInt(1))
})
t.Run("NotEqual", func(t *testing.T) {
src := fmt.Sprintf(srcTmpl, 4, 5)
eval(t, src, big.NewInt(2))
})
}