Merge pull request #962 from nspcc-dev/feature/nested

compiler: support nested struct selectors
This commit is contained in:
Roman Khimov 2020-05-21 11:48:59 +03:00 committed by GitHub
commit 1317666167
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 90 additions and 59 deletions

View file

@ -2,9 +2,7 @@ package compiler
import ( import (
"errors" "errors"
"fmt"
"go/ast" "go/ast"
"go/constant"
"go/types" "go/types"
"github.com/nspcc-dev/neo-go/pkg/vm/emit" "github.com/nspcc-dev/neo-go/pkg/vm/emit"
@ -23,33 +21,6 @@ var (
} }
) )
// typeAndValueForField returns a zero initialized typeAndValue for the given type.Var.
func typeAndValueForField(fld *types.Var) (types.TypeAndValue, error) {
switch t := fld.Type().(type) {
case *types.Basic:
switch t.Kind() {
case types.Int:
return types.TypeAndValue{
Type: t,
Value: constant.MakeInt64(0),
}, nil
case types.String:
return types.TypeAndValue{
Type: t,
Value: constant.MakeString(""),
}, nil
case types.Bool, types.UntypedBool:
return types.TypeAndValue{
Type: t,
Value: constant.MakeBool(false),
}, nil
default:
return types.TypeAndValue{}, fmt.Errorf("could not initialize struct field %s to zero, type: %s", fld.Name(), t)
}
}
return types.TypeAndValue{}, nil
}
// newGlobal creates new global variable. // newGlobal creates new global variable.
func (c *codegen) newGlobal(name string) { func (c *codegen) newGlobal(name string) {
c.globals[name] = len(c.globals) c.globals[name] = len(c.globals)

View file

@ -239,8 +239,15 @@ func (c *codegen) emitDefault(t types.Type) {
emit.Opcode(c.prog.BinWriter, opcode.NEWBUFFER) emit.Opcode(c.prog.BinWriter, opcode.NEWBUFFER)
} }
case *types.Struct: case *types.Struct:
emit.Int(c.prog.BinWriter, int64(t.NumFields())) num := t.NumFields()
emit.Int(c.prog.BinWriter, int64(num))
emit.Opcode(c.prog.BinWriter, opcode.NEWSTRUCT) emit.Opcode(c.prog.BinWriter, opcode.NEWSTRUCT)
for i := 0; i < num; i++ {
emit.Opcode(c.prog.BinWriter, opcode.DUP)
emit.Int(c.prog.BinWriter, int64(i))
c.emitDefault(t.Field(i).Type())
emit.Opcode(c.prog.BinWriter, opcode.SETITEM)
}
default: default:
emit.Opcode(c.prog.BinWriter, opcode.PUSHNULL) emit.Opcode(c.prog.BinWriter, opcode.PUSHNULL)
} }
@ -407,20 +414,17 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
c.emitStoreVar(t.Name) c.emitStoreVar(t.Name)
case *ast.SelectorExpr: case *ast.SelectorExpr:
switch expr := t.X.(type) { if !isAssignOp {
case *ast.Ident: ast.Walk(c, n.Rhs[i])
if !isAssignOp { }
ast.Walk(c, n.Rhs[i]) strct, ok := c.typeOf(t.X).Underlying().(*types.Struct)
} if !ok {
if strct, ok := c.typeOf(expr).Underlying().(*types.Struct); ok {
c.emitLoadVar(expr.Name) // load the struct
i := indexOfStruct(strct, t.Sel.Name) // get the index of the field
c.emitStoreStructField(i) // store the field
}
default:
c.prog.Err = fmt.Errorf("nested selector assigns not supported yet") c.prog.Err = fmt.Errorf("nested selector assigns not supported yet")
return nil return nil
} }
ast.Walk(c, t.X) // load the struct
i := indexOfStruct(strct, t.Sel.Name) // get the index of the field
c.emitStoreStructField(i) // store the field
// Assignments to index expressions. // Assignments to index expressions.
// slice[0] = 10 // slice[0] = 10
@ -428,8 +432,7 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
if !isAssignOp { if !isAssignOp {
ast.Walk(c, n.Rhs[i]) ast.Walk(c, n.Rhs[i])
} }
name := t.X.(*ast.Ident).Name ast.Walk(c, t.X)
c.emitLoadVar(name)
ast.Walk(c, t.Index) ast.Walk(c, t.Index)
emit.Opcode(c.prog.BinWriter, opcode.ROT) emit.Opcode(c.prog.BinWriter, opcode.ROT)
emit.Opcode(c.prog.BinWriter, opcode.SETITEM) emit.Opcode(c.prog.BinWriter, opcode.SETITEM)
@ -753,17 +756,14 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
return nil return nil
case *ast.SelectorExpr: case *ast.SelectorExpr:
switch t := n.X.(type) { strct, ok := c.typeOf(n.X).Underlying().(*types.Struct)
case *ast.Ident: if !ok {
if strct, ok := c.typeOf(t).Underlying().(*types.Struct); ok { c.prog.Err = fmt.Errorf("selectors are supported only on structs")
c.emitLoadVar(t.Name) // load the struct
i := indexOfStruct(strct, n.Sel.Name)
c.emitLoadField(i) // load the field
}
default:
c.prog.Err = fmt.Errorf("nested selectors not supported yet")
return nil return nil
} }
ast.Walk(c, n.X) // load the struct
i := indexOfStruct(strct, n.Sel.Name)
c.emitLoadField(i) // load the field
return nil return nil
case *ast.UnaryExpr: case *ast.UnaryExpr:
@ -1217,15 +1217,9 @@ func (c *codegen) convertStruct(lit *ast.CompositeLit) {
continue continue
} }
typeAndVal, err := typeAndValueForField(sField)
if err != nil {
c.prog.Err = err
return
}
emit.Opcode(c.prog.BinWriter, opcode.DUP) emit.Opcode(c.prog.BinWriter, opcode.DUP)
emit.Int(c.prog.BinWriter, int64(i)) emit.Int(c.prog.BinWriter, int64(i))
c.emitLoadConst(typeAndVal) c.emitDefault(sField.Type())
emit.Opcode(c.prog.BinWriter, opcode.SETITEM) emit.Opcode(c.prog.BinWriter, opcode.SETITEM)
} }
} }

View file

@ -212,6 +212,16 @@ var sliceTestCases = []testCase{
}`, }`,
[]byte{1, 2}, []byte{1, 2},
}, },
{
"nested slice assignment",
`package foo
func Main() int {
a := [][]int{[]int{1, 2}, []int{3, 4}}
a[1][0] = 42
return a[1][0]
}`,
big.NewInt(42),
},
} }
func TestSliceOperations(t *testing.T) { func TestSliceOperations(t *testing.T) {

View file

@ -338,6 +338,62 @@ var structTestCases = []testCase{
}`, }`,
big.NewInt(2), big.NewInt(2),
}, },
{
"nested selectors (simple read)",
`package foo
type S1 struct { x, y S2 }
type S2 struct { a, b int }
func Main() int {
var s1 S1
var s2 S2
s2.a = 3
s1.y = s2
return s1.y.a
}`,
big.NewInt(3),
},
{
"nested selectors (simple write)",
`package foo
type S1 struct { x S2 }
type S2 struct { a int }
func Main() int {
s1 := S1{
x: S2 {
a: 3,
},
}
s1.x.a = 11
return s1.x.a
}`,
big.NewInt(11),
},
{
"complex struct default value",
`package foo
type S1 struct { x S2 }
type S2 struct { y S3 }
type S3 struct { a int }
func Main() int {
var s1 S1
s1.x.y.a = 11
return s1.x.y.a
}`,
big.NewInt(11),
},
{
"nested selectors (complex write)",
`package foo
type S1 struct { x S2 }
type S2 struct { y, z S3 }
type S3 struct { a int }
func Main() int {
var s1 S1
s1.x.y.a, s1.x.z.a = 11, 31
return s1.x.y.a + s1.x.z.a
}`,
big.NewInt(42),
},
} }
func TestStructs(t *testing.T) { func TestStructs(t *testing.T) {