forked from TrueCloudLab/neoneo-go
Merge pull request #952 from nspcc-dev/fix/structfield
compiler: set default values for complex struct fields (2.x)
This commit is contained in:
commit
9615e83c00
5 changed files with 132 additions and 23 deletions
|
@ -44,8 +44,9 @@ func typeAndValueForField(fld *types.Var) (types.TypeAndValue, error) {
|
||||||
default:
|
default:
|
||||||
return types.TypeAndValue{}, fmt.Errorf("could not initialize struct field %s to zero, type: %s", fld.Name(), t)
|
return types.TypeAndValue{}, fmt.Errorf("could not initialize struct field %s to zero, type: %s", fld.Name(), t)
|
||||||
}
|
}
|
||||||
|
default:
|
||||||
|
return types.TypeAndValue{Type: t}, nil
|
||||||
}
|
}
|
||||||
return types.TypeAndValue{}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// countGlobals counts the global variables in the program to add
|
// countGlobals counts the global variables in the program to add
|
||||||
|
|
|
@ -470,7 +470,6 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
|
||||||
case *ast.SwitchStmt:
|
case *ast.SwitchStmt:
|
||||||
ast.Walk(c, n.Tag)
|
ast.Walk(c, n.Tag)
|
||||||
|
|
||||||
eqOpcode := c.getEqualityOpcode(n.Tag)
|
|
||||||
switchEnd, label := c.generateLabel(labelEnd)
|
switchEnd, label := c.generateLabel(labelEnd)
|
||||||
|
|
||||||
lastSwitch := c.currentSwitch
|
lastSwitch := c.currentSwitch
|
||||||
|
@ -490,7 +489,7 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
|
||||||
for j := range cc.List {
|
for j := range cc.List {
|
||||||
emit.Opcode(c.prog.BinWriter, opcode.DUP)
|
emit.Opcode(c.prog.BinWriter, opcode.DUP)
|
||||||
ast.Walk(c, cc.List[j])
|
ast.Walk(c, cc.List[j])
|
||||||
emit.Opcode(c.prog.BinWriter, eqOpcode)
|
c.emitEquality(n.Tag, token.EQL)
|
||||||
if j == l-1 {
|
if j == l-1 {
|
||||||
emit.Jmp(c.prog.BinWriter, opcode.JMPIFNOT, lEnd)
|
emit.Jmp(c.prog.BinWriter, opcode.JMPIFNOT, lEnd)
|
||||||
} else {
|
} else {
|
||||||
|
@ -533,6 +532,8 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
|
||||||
c.emitLoadConst(value)
|
c.emitLoadConst(value)
|
||||||
} else if tv := c.typeInfo.Types[n]; tv.Value != nil {
|
} else if tv := c.typeInfo.Types[n]; tv.Value != nil {
|
||||||
c.emitLoadConst(tv)
|
c.emitLoadConst(tv)
|
||||||
|
} else if n.Name == "nil" {
|
||||||
|
c.emitDefault(new(types.Slice))
|
||||||
} else {
|
} else {
|
||||||
c.emitLoadLocal(n.Name)
|
c.emitLoadLocal(n.Name)
|
||||||
}
|
}
|
||||||
|
@ -615,26 +616,20 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
|
||||||
ast.Walk(c, n.X)
|
ast.Walk(c, n.X)
|
||||||
ast.Walk(c, n.Y)
|
ast.Walk(c, n.Y)
|
||||||
|
|
||||||
switch {
|
switch n.Op {
|
||||||
case n.Op == token.ADD:
|
case token.ADD:
|
||||||
// VM has separate opcodes for number and string concatenation
|
// VM has separate opcodes for number and string concatenation
|
||||||
if isStringType(tinfo.Type) {
|
if isStringType(tinfo.Type) {
|
||||||
emit.Opcode(c.prog.BinWriter, opcode.CAT)
|
emit.Opcode(c.prog.BinWriter, opcode.CAT)
|
||||||
} else {
|
} else {
|
||||||
emit.Opcode(c.prog.BinWriter, opcode.ADD)
|
emit.Opcode(c.prog.BinWriter, opcode.ADD)
|
||||||
}
|
}
|
||||||
case n.Op == token.EQL:
|
case token.EQL, token.NEQ:
|
||||||
// VM has separate opcodes for number and string equality
|
if isExprNil(n.X) || isExprNil(n.Y) {
|
||||||
op := c.getEqualityOpcode(n.X)
|
c.prog.Err = errors.New("comparison with `nil` is not supported, use `len(..) == 0` instead")
|
||||||
emit.Opcode(c.prog.BinWriter, op)
|
return nil
|
||||||
case n.Op == token.NEQ:
|
|
||||||
// VM has separate opcodes for number and string equality
|
|
||||||
if isStringType(c.typeInfo.Types[n.X].Type) {
|
|
||||||
emit.Opcode(c.prog.BinWriter, opcode.EQUAL)
|
|
||||||
emit.Opcode(c.prog.BinWriter, opcode.NOT)
|
|
||||||
} else {
|
|
||||||
emit.Opcode(c.prog.BinWriter, opcode.NUMNOTEQUAL)
|
|
||||||
}
|
}
|
||||||
|
c.emitEquality(n.X, n.Op)
|
||||||
default:
|
default:
|
||||||
c.convertToken(n.Op)
|
c.convertToken(n.Op)
|
||||||
}
|
}
|
||||||
|
@ -980,13 +975,26 @@ func (c *codegen) getLabelOffset(typ labelOffsetType, name string) uint16 {
|
||||||
return c.labels[labelWithType{name: name, typ: typ}]
|
return c.labels[labelWithType{name: name, typ: typ}]
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *codegen) getEqualityOpcode(expr ast.Expr) opcode.Opcode {
|
func (c *codegen) emitEquality(expr ast.Expr, op token.Token) {
|
||||||
t, ok := c.typeInfo.Types[expr].Type.Underlying().(*types.Basic)
|
t, ok := c.typeInfo.Types[expr].Type.Underlying().(*types.Basic)
|
||||||
if ok && t.Info()&types.IsNumeric != 0 {
|
isNum := ok && t.Info()&types.IsNumeric != 0
|
||||||
return opcode.NUMEQUAL
|
switch op {
|
||||||
|
case token.EQL:
|
||||||
|
if isNum {
|
||||||
|
emit.Opcode(c.prog.BinWriter, opcode.NUMEQUAL)
|
||||||
|
} else {
|
||||||
|
emit.Opcode(c.prog.BinWriter, opcode.EQUAL)
|
||||||
|
}
|
||||||
|
case token.NEQ:
|
||||||
|
if isNum {
|
||||||
|
emit.Opcode(c.prog.BinWriter, opcode.NUMNOTEQUAL)
|
||||||
|
} else {
|
||||||
|
emit.Opcode(c.prog.BinWriter, opcode.EQUAL)
|
||||||
|
emit.Opcode(c.prog.BinWriter, opcode.NOT)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
panic("invalid token in emitEqual()")
|
||||||
}
|
}
|
||||||
|
|
||||||
return opcode.EQUAL
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// getByteArray returns byte array value from constant expr.
|
// getByteArray returns byte array value from constant expr.
|
||||||
|
@ -1230,11 +1238,32 @@ func (c *codegen) convertStruct(lit *ast.CompositeLit) {
|
||||||
|
|
||||||
emit.Opcode(c.prog.BinWriter, opcode.DUP)
|
emit.Opcode(c.prog.BinWriter, opcode.DUP)
|
||||||
emit.Int(c.prog.BinWriter, int64(i))
|
emit.Int(c.prog.BinWriter, int64(i))
|
||||||
c.emitLoadConst(typeAndVal)
|
c.emitDefault(typeAndVal.Type)
|
||||||
emit.Opcode(c.prog.BinWriter, opcode.SETITEM)
|
emit.Opcode(c.prog.BinWriter, opcode.SETITEM)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *codegen) emitDefault(typ types.Type) {
|
||||||
|
switch t := c.scTypeFromGo(typ); t {
|
||||||
|
case "Integer":
|
||||||
|
emit.Int(c.prog.BinWriter, 0)
|
||||||
|
case "Boolean":
|
||||||
|
emit.Bool(c.prog.BinWriter, false)
|
||||||
|
case "String":
|
||||||
|
emit.String(c.prog.BinWriter, "")
|
||||||
|
case "Map":
|
||||||
|
emit.Opcode(c.prog.BinWriter, opcode.NEWMAP)
|
||||||
|
case "Struct":
|
||||||
|
emit.Int(c.prog.BinWriter, int64(typ.(*types.Struct).NumFields()))
|
||||||
|
emit.Opcode(c.prog.BinWriter, opcode.NEWSTRUCT)
|
||||||
|
case "Array":
|
||||||
|
emit.Int(c.prog.BinWriter, 0)
|
||||||
|
emit.Opcode(c.prog.BinWriter, opcode.NEWARRAY)
|
||||||
|
case "ByteArray":
|
||||||
|
emit.Bytes(c.prog.BinWriter, []byte{})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (c *codegen) convertToken(tok token.Token) {
|
func (c *codegen) convertToken(tok token.Token) {
|
||||||
switch tok {
|
switch tok {
|
||||||
case token.ADD_ASSIGN:
|
case token.ADD_ASSIGN:
|
||||||
|
|
|
@ -183,7 +183,11 @@ func (c *codegen) scReturnTypeFromScope(scope *funcScope) string {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *codegen) scTypeFromExpr(typ ast.Expr) string {
|
func (c *codegen) scTypeFromExpr(typ ast.Expr) string {
|
||||||
switch t := c.typeInfo.Types[typ].Type.(type) {
|
return c.scTypeFromGo(c.typeInfo.Types[typ].Type)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *codegen) scTypeFromGo(typ types.Type) string {
|
||||||
|
switch t := typ.(type) {
|
||||||
case *types.Basic:
|
case *types.Basic:
|
||||||
info := t.Info()
|
info := t.Info()
|
||||||
switch {
|
switch {
|
||||||
|
|
|
@ -1,10 +1,14 @@
|
||||||
package compiler_test
|
package compiler_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"math/big"
|
"math/big"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/compiler"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/vm"
|
"github.com/nspcc-dev/neo-go/pkg/vm"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
var sliceTestCases = []testCase{
|
var sliceTestCases = []testCase{
|
||||||
|
@ -175,6 +179,31 @@ func TestSliceOperations(t *testing.T) {
|
||||||
runTestCases(t, sliceTestCases)
|
runTestCases(t, sliceTestCases)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSliceEmpty(t *testing.T) {
|
||||||
|
srcTmpl := `package foo
|
||||||
|
func Main() int {
|
||||||
|
var a []int
|
||||||
|
%s
|
||||||
|
if %s {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
return 2
|
||||||
|
}`
|
||||||
|
t.Run("WithNil", func(t *testing.T) {
|
||||||
|
src := fmt.Sprintf(srcTmpl, "", "a == nil")
|
||||||
|
_, err := compiler.Compile(strings.NewReader(src))
|
||||||
|
require.Error(t, err)
|
||||||
|
})
|
||||||
|
t.Run("WithLen", func(t *testing.T) {
|
||||||
|
src := fmt.Sprintf(srcTmpl, "", "len(a) == 0")
|
||||||
|
eval(t, src, big.NewInt(1))
|
||||||
|
})
|
||||||
|
t.Run("NonEmpty", func(t *testing.T) {
|
||||||
|
src := fmt.Sprintf(srcTmpl, "a = []int{1}", "len(a) == 0")
|
||||||
|
eval(t, src, big.NewInt(2))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestJumps(t *testing.T) {
|
func TestJumps(t *testing.T) {
|
||||||
src := `
|
src := `
|
||||||
package foo
|
package foo
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package compiler_test
|
package compiler_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"math/big"
|
"math/big"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
@ -338,8 +339,53 @@ var structTestCases = []testCase{
|
||||||
}`,
|
}`,
|
||||||
big.NewInt(2),
|
big.NewInt(2),
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"uninitialized struct fields",
|
||||||
|
`package foo
|
||||||
|
type Foo struct {
|
||||||
|
i int
|
||||||
|
m map[string]int
|
||||||
|
b []byte
|
||||||
|
a []int
|
||||||
|
s struct { ii int }
|
||||||
|
}
|
||||||
|
func NewFoo() Foo { return Foo{} }
|
||||||
|
func Main() int {
|
||||||
|
foo := NewFoo()
|
||||||
|
if foo.i != 0 { return 1 }
|
||||||
|
if len(foo.m) != 0 { return 1 }
|
||||||
|
if len(foo.b) != 0 { return 1 }
|
||||||
|
if len(foo.a) != 0 { return 1 }
|
||||||
|
s := foo.s
|
||||||
|
if s.ii != 0 { return 1 }
|
||||||
|
return 2
|
||||||
|
}`,
|
||||||
|
big.NewInt(2),
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestStructs(t *testing.T) {
|
func TestStructs(t *testing.T) {
|
||||||
runTestCases(t, structTestCases)
|
runTestCases(t, structTestCases)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestStructCompare(t *testing.T) {
|
||||||
|
srcTmpl := `package testcase
|
||||||
|
type T struct { f int }
|
||||||
|
func Main() int {
|
||||||
|
a := T{f: %d}
|
||||||
|
b := T{f: %d}
|
||||||
|
if a != b {
|
||||||
|
return 2
|
||||||
|
}
|
||||||
|
return 1
|
||||||
|
}`
|
||||||
|
t.Run("Equal", func(t *testing.T) {
|
||||||
|
src := fmt.Sprintf(srcTmpl, 4, 4)
|
||||||
|
eval(t, src, big.NewInt(1))
|
||||||
|
})
|
||||||
|
t.Run("NotEqual", func(t *testing.T) {
|
||||||
|
src := fmt.Sprintf(srcTmpl, 4, 5)
|
||||||
|
eval(t, src, big.NewInt(2))
|
||||||
|
})
|
||||||
|
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue