diff --git a/pkg/compiler/analysis.go b/pkg/compiler/analysis.go index bf3272baa..90dd5e318 100644 --- a/pkg/compiler/analysis.go +++ b/pkg/compiler/analysis.go @@ -187,6 +187,35 @@ func isBuiltin(expr ast.Expr) bool { return false } +func (c *codegen) isCompoundArrayType(t ast.Expr) bool { + switch s := t.(type) { + case *ast.ArrayType: + return true + case *ast.Ident: + arr, ok := c.typeInfo.Types[s].Type.Underlying().(*types.Slice) + return ok && !isByte(arr.Elem()) + } + return false +} + +func isByte(t types.Type) bool { + e, ok := t.(*types.Basic) + return ok && e.Kind() == types.Byte +} + +func (c *codegen) isStructType(t ast.Expr) (int, bool) { + switch s := t.(type) { + case *ast.StructType: + return s.Fields.NumFields(), true + case *ast.Ident: + st, ok := c.typeInfo.Types[s].Type.Underlying().(*types.Struct) + if ok { + return st.NumFields(), true + } + } + return 0, false +} + func isByteArray(lit *ast.CompositeLit, tInfo *types.Info) bool { if len(lit.Elts) == 0 { if typ, ok := lit.Type.(*ast.ArrayType); ok { @@ -199,14 +228,7 @@ func isByteArray(lit *ast.CompositeLit, tInfo *types.Info) bool { } typ := tInfo.Types[lit.Elts[0]].Type.Underlying() - switch t := typ.(type) { - case *types.Basic: - switch t.Kind() { - case types.Byte: - return true - } - } - return false + return isByte(typ) } func isSyscall(fun *funcScope) bool { diff --git a/pkg/compiler/codegen.go b/pkg/compiler/codegen.go index 5761f94f9..ddd07a48e 100644 --- a/pkg/compiler/codegen.go +++ b/pkg/compiler/codegen.go @@ -276,9 +276,21 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor { for _, spec := range n.Specs { switch t := spec.(type) { case *ast.ValueSpec: - for i, val := range t.Values { - ast.Walk(c, val) - l := c.scope.newLocal(t.Names[i].Name) + if len(t.Values) != 0 { + for i, val := range t.Values { + ast.Walk(c, val) + l := c.scope.newLocal(t.Names[i].Name) + c.emitStoreLocal(l) + } + } else if c.isCompoundArrayType(t.Type) { + emit.Opcode(c.prog.BinWriter, opcode.PUSH0) + emit.Opcode(c.prog.BinWriter, opcode.NEWARRAY) + l := c.scope.newLocal(t.Names[0].Name) + c.emitStoreLocal(l) + } else if n, ok := c.isStructType(t.Type); ok { + emit.Int(c.prog.BinWriter, int64(n)) + emit.Opcode(c.prog.BinWriter, opcode.NEWSTRUCT) + l := c.scope.newLocal(t.Names[0].Name) c.emitStoreLocal(l) } } diff --git a/pkg/compiler/slice_test.go b/pkg/compiler/slice_test.go index a9b0ce8f8..40600ea28 100644 --- a/pkg/compiler/slice_test.go +++ b/pkg/compiler/slice_test.go @@ -3,6 +3,8 @@ package compiler_test import ( "math/big" "testing" + + "github.com/nspcc-dev/neo-go/pkg/vm" ) var sliceTestCases = []testCase{ @@ -128,6 +130,35 @@ var sliceTestCases = []testCase{ }`, []byte{2, 3}, }, + { + "declare compound slice", + `package foo + func Main() []string { + var a []string + a = append(a, "a") + a = append(a, "b") + return a + }`, + []vm.StackItem{ + vm.NewByteArrayItem([]byte("a")), + vm.NewByteArrayItem([]byte("b")), + }, + }, + { + "declare compound slice alias", + `package foo + type strs []string + func Main() []string { + var a strs + a = append(a, "a") + a = append(a, "b") + return a + }`, + []vm.StackItem{ + vm.NewByteArrayItem([]byte("a")), + vm.NewByteArrayItem([]byte("b")), + }, + }, } func TestSliceOperations(t *testing.T) { diff --git a/pkg/compiler/struct_test.go b/pkg/compiler/struct_test.go index 964edbbd4..1dd63e07f 100644 --- a/pkg/compiler/struct_test.go +++ b/pkg/compiler/struct_test.go @@ -302,6 +302,31 @@ var structTestCases = []testCase{ `, big.NewInt(14), }, + { + "declare struct literal", + `package foo + func Main() int { + var x struct { + a int + } + x.a = 2 + return x.a + }`, + big.NewInt(2), + }, + { + "declare struct type", + `package foo + type withA struct { + a int + } + func Main() int { + var x withA + x.a = 2 + return x.a + }`, + big.NewInt(2), + }, } func TestStructs(t *testing.T) {