diff --git a/pkg/compiler/analysis.go b/pkg/compiler/analysis.go index c2382cfd9..f096159e8 100644 --- a/pkg/compiler/analysis.go +++ b/pkg/compiler/analysis.go @@ -70,50 +70,28 @@ func (c *codegen) traverseGlobals(f ast.Node) { // countGlobals counts the global variables in the program to add // them with the stack size of the function. -func countGlobals(f ast.Node) (i int64) { +func countGlobals(f ast.Node) (i int) { ast.Inspect(f, func(node ast.Node) bool { - switch node.(type) { + switch n := node.(type) { // Skip all function declarations. case *ast.FuncDecl: return false // After skipping all funcDecls we are sure that each value spec // is a global declared variable or constant. case *ast.ValueSpec: - i++ + i += len(n.Names) } return true }) return } -// isIdentBool looks if the given ident is a boolean. -func isIdentBool(ident *ast.Ident) bool { - return ident.Name == "true" || ident.Name == "false" -} - // isExprNil looks if the given expression is a `nil`. func isExprNil(e ast.Expr) bool { v, ok := e.(*ast.Ident) return ok && v.Name == "nil" } -// makeBoolFromIdent creates a bool type from an *ast.Ident. -func makeBoolFromIdent(ident *ast.Ident, tinfo *types.Info) (types.TypeAndValue, error) { - var b bool - switch ident.Name { - case "true": - b = true - case "false": - b = false - default: - return types.TypeAndValue{}, fmt.Errorf("givent identifier cannot be converted to a boolean => %s", ident.Name) - } - return types.TypeAndValue{ - Type: tinfo.ObjectOf(ident).Type(), - Value: constant.MakeBool(b), - }, nil -} - // resolveEntryPoint returns the function declaration of the entrypoint and the corresponding file. func resolveEntryPoint(entry string, pkg *loader.PackageInfo) (*ast.FuncDecl, *ast.File) { var ( @@ -205,50 +183,6 @@ func isBuiltin(expr ast.Expr) bool { return false } -func (c *codegen) isCompoundArrayType(t ast.Expr) bool { - switch s := t.(type) { - case *ast.ArrayType: - return true - case *ast.Ident: - arr, ok := c.typeInfo.Types[s].Type.Underlying().(*types.Slice) - return ok && !isByte(arr.Elem()) - } - return false -} - -func isByte(t types.Type) bool { - e, ok := t.(*types.Basic) - return ok && e.Kind() == types.Byte -} - -func (c *codegen) isStructType(t ast.Expr) (int, bool) { - switch s := t.(type) { - case *ast.StructType: - return s.Fields.NumFields(), true - case *ast.Ident: - st, ok := c.typeInfo.Types[s].Type.Underlying().(*types.Struct) - if ok { - return st.NumFields(), true - } - } - return 0, false -} - -func isByteArray(lit *ast.CompositeLit, tInfo *types.Info) bool { - if len(lit.Elts) == 0 { - if typ, ok := lit.Type.(*ast.ArrayType); ok { - if name, ok := typ.Elt.(*ast.Ident); ok { - return name.Name == "byte" || name.Name == "uint8" - } - } - - return false - } - - typ := tInfo.Types[lit.Elts[0]].Type.Underlying() - return isByte(typ) -} - func isSyscall(fun *funcScope) bool { if fun.selector == nil { return false @@ -256,11 +190,3 @@ func isSyscall(fun *funcScope) bool { _, ok := syscalls[fun.selector.Name][fun.name] return ok } - -func isByteArrayType(t types.Type) bool { - return t.String() == "[]byte" -} - -func isStringType(t types.Type) bool { - return t.String() == "string" -} diff --git a/pkg/compiler/codegen.go b/pkg/compiler/codegen.go index ca1766d97..781d5a5a0 100644 --- a/pkg/compiler/codegen.go +++ b/pkg/compiler/codegen.go @@ -10,7 +10,6 @@ import ( "go/types" "math" "sort" - "strconv" "strings" "github.com/nspcc-dev/neo-go/pkg/encoding/address" @@ -125,16 +124,13 @@ func (c *codegen) emitLoadConst(t types.TypeAndValue) { if c.prog.Err != nil { return } - switch typ := t.Type.Underlying().(type) { - case *types.Basic: - c.convertBasicType(t, typ) - default: + + typ, ok := t.Type.Underlying().(*types.Basic) + if !ok { c.prog.Err = fmt.Errorf("compiler doesn't know how to convert this constant: %v", t) return } -} -func (c *codegen) convertBasicType(t types.TypeAndValue, typ *types.Basic) { switch typ.Kind() { case types.Int, types.UntypedInt, types.Uint, types.Int16, types.Uint16, @@ -211,6 +207,10 @@ func (c *codegen) emitLoadVar(name string) { // emitStoreVar stores top value from the evaluation stack in the specified variable. func (c *codegen) emitStoreVar(name string) { + if name == "_" { + emit.Opcode(c.prog.BinWriter, opcode.DROP) + return + } t, i := c.getVarIndex(name) _, base := getBaseOpcode(t) if i < 7 { @@ -220,13 +220,9 @@ func (c *codegen) emitStoreVar(name string) { } } -func (c *codegen) emitDefault(n ast.Expr) { - tv, ok := c.typeInfo.Types[n] - if !ok { - c.prog.Err = errors.New("invalid type") - return - } - if t, ok := tv.Type.(*types.Basic); ok { +func (c *codegen) emitDefault(t types.Type) { + switch t := t.Underlying().(type) { + case *types.Basic: info := t.Info() switch { case info&types.IsInteger != 0: @@ -238,9 +234,18 @@ func (c *codegen) emitDefault(n ast.Expr) { default: emit.Opcode(c.prog.BinWriter, opcode.PUSHNULL) } - return + case *types.Slice: + if isCompoundSlice(t) { + emit.Opcode(c.prog.BinWriter, opcode.NEWARRAY0) + } else { + emit.Bytes(c.prog.BinWriter, []byte{}) + } + case *types.Struct: + emit.Int(c.prog.BinWriter, int64(t.NumFields())) + emit.Opcode(c.prog.BinWriter, opcode.NEWSTRUCT) + default: + emit.Opcode(c.prog.BinWriter, opcode.PUSHNULL) } - emit.Opcode(c.prog.BinWriter, opcode.PUSHNULL) } // convertGlobals traverses the AST and only converts global declarations. @@ -308,15 +313,8 @@ func (c *codegen) convertFuncDecl(file ast.Node, decl *ast.FuncDecl) { // to support other types. if decl.Recv != nil { for _, arg := range decl.Recv.List { - ident := arg.Names[0] - // Currently only method receives for struct types is supported. - _, ok := c.typeInfo.Defs[ident].Type().Underlying().(*types.Struct) - if !ok { - c.prog.Err = fmt.Errorf("method receives for non-struct types is not yet supported") - return - } // only create an argument here, it will be stored via INITSLOT - c.scope.newVariable(varArgument, ident.Name) + c.scope.newVariable(varArgument, arg.Names[0].Name) } } @@ -371,24 +369,13 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor { } c.registerDebugVariable(id.Name, t.Type) } - if len(t.Values) != 0 { - for i, val := range t.Values { - ast.Walk(c, val) - c.emitStoreVar(t.Names[i].Name) - } - } else if c.isCompoundArrayType(t.Type) { - emit.Opcode(c.prog.BinWriter, opcode.PUSH0) - emit.Opcode(c.prog.BinWriter, opcode.NEWARRAY) - c.emitStoreVar(t.Names[0].Name) - } else if n, ok := c.isStructType(t.Type); ok { - emit.Int(c.prog.BinWriter, int64(n)) - emit.Opcode(c.prog.BinWriter, opcode.NEWSTRUCT) - c.emitStoreVar(t.Names[0].Name) - } else { - for _, id := range t.Names { - c.emitDefault(t.Type) - c.emitStoreVar(id.Name) + for i := range t.Names { + if len(t.Values) != 0 { + ast.Walk(c, t.Values[i]) + } else { + c.emitDefault(c.typeOf(t.Type)) } + c.emitStoreVar(t.Names[i].Name) } } } @@ -397,41 +384,37 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor { case *ast.AssignStmt: multiRet := len(n.Rhs) != len(n.Lhs) c.saveSequencePoint(n) + // Assign operations are grouped https://github.com/golang/go/blob/master/src/go/types/stmt.go#L160 + isAssignOp := token.ADD_ASSIGN <= n.Tok && n.Tok <= token.AND_NOT_ASSIGN + if isAssignOp { + // RHS can contain exactly one expression, thus there is no need to iterate. + ast.Walk(c, n.Lhs[0]) + ast.Walk(c, n.Rhs[0]) + c.convertToken(n.Tok) + } for i := 0; i < len(n.Lhs); i++ { switch t := n.Lhs[i].(type) { case *ast.Ident: - switch n.Tok { - case token.ADD_ASSIGN, token.SUB_ASSIGN, token.MUL_ASSIGN, token.QUO_ASSIGN, token.REM_ASSIGN: - c.emitLoadVar(t.Name) - ast.Walk(c, n.Rhs[0]) // can only add assign to 1 expr on the RHS - c.convertToken(n.Tok) - c.emitStoreVar(t.Name) - case token.DEFINE: + if n.Tok == token.DEFINE { if !multiRet { c.registerDebugVariable(t.Name, n.Rhs[i]) } if t.Name != "_" { c.scope.newLocal(t.Name) } - fallthrough - default: - if i == 0 || !multiRet { - ast.Walk(c, n.Rhs[i]) - } - - if t.Name == "_" { - emit.Opcode(c.prog.BinWriter, opcode.DROP) - } else { - c.emitStoreVar(t.Name) - } } + if !isAssignOp && (i == 0 || !multiRet) { + ast.Walk(c, n.Rhs[i]) + } + c.emitStoreVar(t.Name) case *ast.SelectorExpr: switch expr := t.X.(type) { case *ast.Ident: - ast.Walk(c, n.Rhs[i]) - typ := c.typeInfo.ObjectOf(expr).Type().Underlying() - if strct, ok := typ.(*types.Struct); ok { + if !isAssignOp { + ast.Walk(c, n.Rhs[i]) + } + if strct, ok := c.typeOf(expr).Underlying().(*types.Struct); ok { c.emitLoadVar(expr.Name) // load the struct i := indexOfStruct(strct, t.Sel.Name) // get the index of the field c.emitStoreStructField(i) // store the field @@ -444,26 +427,14 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor { // Assignments to index expressions. // slice[0] = 10 case *ast.IndexExpr: - ast.Walk(c, n.Rhs[i]) + if !isAssignOp { + ast.Walk(c, n.Rhs[i]) + } name := t.X.(*ast.Ident).Name c.emitLoadVar(name) - switch ind := t.Index.(type) { - case *ast.BasicLit: - indexStr := ind.Value - index, err := strconv.Atoi(indexStr) - if err != nil { - c.prog.Err = fmt.Errorf("failed to convert slice index to integer") - return nil - } - c.emitStoreStructField(index) - case *ast.Ident: - c.emitLoadVar(ind.Name) - emit.Opcode(c.prog.BinWriter, opcode.ROT) - emit.Opcode(c.prog.BinWriter, opcode.SETITEM) - default: - c.prog.Err = fmt.Errorf("unsupported index expression") - return nil - } + ast.Walk(c, t.Index) + emit.Opcode(c.prog.BinWriter, opcode.ROT) + emit.Opcode(c.prog.BinWriter, opcode.SETITEM) } } return nil @@ -607,18 +578,11 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor { return nil case *ast.BasicLit: - c.emitLoadConst(c.typeInfo.Types[n]) + c.emitLoadConst(c.typeAndValueOf(n)) return nil case *ast.Ident: - if isIdentBool(n) { - value, err := makeBoolFromIdent(n, c.typeInfo) - if err != nil { - c.prog.Err = err - return nil - } - c.emitLoadConst(value) - } else if tv := c.typeInfo.Types[n]; tv.Value != nil { + if tv := c.typeAndValueOf(n); tv.Value != nil { c.emitLoadConst(tv) } else { c.emitLoadVar(n.Name) @@ -626,19 +590,19 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor { return nil case *ast.CompositeLit: - var typ types.Type - - switch t := n.Type.(type) { - case *ast.Ident: - typ = c.typeInfo.ObjectOf(t).Type().Underlying() - case *ast.SelectorExpr: - typ = c.typeInfo.ObjectOf(t.Sel).Type().Underlying() - case *ast.MapType: - typ = c.typeInfo.TypeOf(t) + typ := c.typeOf(n.Type).Underlying() + switch n.Type.(type) { + case *ast.Ident, *ast.SelectorExpr, *ast.MapType: + switch typ.(type) { + case *types.Struct: + c.convertStruct(n) + case *types.Map: + c.convertMap(n) + } default: ln := len(n.Elts) // ByteArrays needs a different approach than normal arrays. - if isByteArray(n, c.typeInfo) { + if isByteSlice(typ) { c.convertByteArray(n) return nil } @@ -647,14 +611,6 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor { } emit.Int(c.prog.BinWriter, int64(ln)) emit.Opcode(c.prog.BinWriter, opcode.PACK) - return nil - } - - switch typ.(type) { - case *types.Struct: - c.convertStruct(n) - case *types.Map: - c.convertMap(n) } return nil @@ -693,7 +649,7 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor { // example: // const x = 10 // x + 2 will results into 12 - tinfo := c.typeInfo.Types[n] + tinfo := c.typeAndValueOf(n) if tinfo.Value != nil { c.emitLoadConst(tinfo) return nil @@ -705,7 +661,7 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor { switch { case n.Op == token.ADD: // VM has separate opcodes for number and string concatenation - if isStringType(tinfo.Type) { + if isString(tinfo.Type) { emit.Opcode(c.prog.BinWriter, opcode.CAT) } else { emit.Opcode(c.prog.BinWriter, opcode.ADD) @@ -716,7 +672,7 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor { emit.Opcode(c.prog.BinWriter, op) case n.Op == token.NEQ: // VM has separate opcodes for number and string equality - if isStringType(c.typeInfo.Types[n.X].Type) { + if isString(c.typeOf(n.X)) { emit.Opcode(c.prog.BinWriter, opcode.NOTEQUAL) } else { emit.Opcode(c.prog.BinWriter, opcode.NUMNOTEQUAL) @@ -800,8 +756,7 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor { case *ast.SelectorExpr: switch t := n.X.(type) { case *ast.Ident: - typ := c.typeInfo.ObjectOf(t).Type().Underlying() - if strct, ok := typ.(*types.Struct); ok { + if strct, ok := c.typeOf(t).Underlying().(*types.Struct); ok { c.emitLoadVar(t.Name) // load the struct i := indexOfStruct(strct, n.Sel.Name) c.emitLoadField(i) // load the field @@ -849,21 +804,7 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor { // Walk the expression, this could be either an Ident or SelectorExpr. // This will load local whatever X is. ast.Walk(c, n.X) - - switch n.Index.(type) { - case *ast.BasicLit: - t := c.typeInfo.Types[n.Index] - switch typ := t.Type.Underlying().(type) { - case *types.Basic: - c.convertBasicType(t, typ) - default: - c.prog.Err = fmt.Errorf("compiler can't use following type as an index: %T", typ) - return nil - } - default: - ast.Walk(c, n.Index) - } - + ast.Walk(c, n.Index) emit.Opcode(c.prog.BinWriter, opcode.PICKITEM) // just pickitem here return nil @@ -945,13 +886,6 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor { return nil case *ast.RangeStmt: - // currently only simple for-range loops are supported - // for i := range ... - if n.Value != nil { - c.prog.Err = errors.New("range loops with value variable are not supported") - return nil - } - start, label := c.generateLabel(labelStart) end := c.newNamedLabel(labelEnd, label) post := c.newNamedLabel(labelPost, label) @@ -962,28 +896,30 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor { c.currentSwitch = label ast.Walk(c, n.X) + emit.Syscall(c.prog.BinWriter, "Neo.Iterator.Create") - emit.Opcode(c.prog.BinWriter, opcode.SIZE) - emit.Opcode(c.prog.BinWriter, opcode.PUSH0) - - c.pushStackLabel(label, 2) + c.pushStackLabel(label, 1) c.setLabel(start) - emit.Opcode(c.prog.BinWriter, opcode.OVER) - emit.Opcode(c.prog.BinWriter, opcode.OVER) - emit.Opcode(c.prog.BinWriter, opcode.LTE) // finish if len <= i - emit.Jmp(c.prog.BinWriter, opcode.JMPIFL, end) + emit.Opcode(c.prog.BinWriter, opcode.DUP) + emit.Syscall(c.prog.BinWriter, "Neo.Enumerator.Next") + emit.Jmp(c.prog.BinWriter, opcode.JMPIFNOTL, end) if n.Key != nil { emit.Opcode(c.prog.BinWriter, opcode.DUP) + emit.Syscall(c.prog.BinWriter, "Neo.Iterator.Key") c.emitStoreVar(n.Key.(*ast.Ident).Name) } + if n.Value != nil { + emit.Opcode(c.prog.BinWriter, opcode.DUP) + emit.Syscall(c.prog.BinWriter, "Neo.Enumerator.Value") + c.emitStoreVar(n.Value.(*ast.Ident).Name) + } ast.Walk(c, n.Body) c.setLabel(post) - emit.Opcode(c.prog.BinWriter, opcode.INC) emit.Jmp(c.prog.BinWriter, opcode.JMPL, start) c.setLabel(end) @@ -1067,7 +1003,7 @@ func (c *codegen) getLabelOffset(typ labelOffsetType, name string) uint16 { } func (c *codegen) getEqualityOpcode(expr ast.Expr) opcode.Opcode { - t, ok := c.typeInfo.Types[expr].Type.Underlying().(*types.Basic) + t, ok := c.typeOf(expr).Underlying().(*types.Basic) if ok && t.Info()&types.IsNumeric != 0 { return opcode.NUMEQUAL } @@ -1080,18 +1016,18 @@ func (c *codegen) getEqualityOpcode(expr ast.Expr) opcode.Opcode { func (c *codegen) getByteArray(expr ast.Expr) []byte { switch t := expr.(type) { case *ast.CompositeLit: - if !isByteArray(t, c.typeInfo) { + if !isByteSlice(c.typeOf(t.Type)) { return nil } buf := make([]byte, len(t.Elts)) for i := 0; i < len(t.Elts); i++ { - t := c.typeInfo.Types[t.Elts[i]] + t := c.typeAndValueOf(t.Elts[i]) val, _ := constant.Int64Val(t.Value) buf[i] = byte(val) } return buf case *ast.CallExpr: - if tv := c.typeInfo.Types[t.Args[0]]; tv.Value != nil { + if tv := c.typeAndValueOf(t.Args[0]); tv.Value != nil { val := constant.StringVal(tv.Value) return []byte(val) } @@ -1136,7 +1072,7 @@ func (c *codegen) convertBuiltin(expr *ast.CallExpr) { case "append": arg := expr.Args[0] typ := c.typeInfo.Types[arg].Type - if isByteArrayType(typ) { + if isByteSlice(typ) { emit.Opcode(c.prog.BinWriter, opcode.CAT) } else { emit.Opcode(c.prog.BinWriter, opcode.OVER) @@ -1148,7 +1084,7 @@ func (c *codegen) convertBuiltin(expr *ast.CallExpr) { if isExprNil(arg) { emit.Opcode(c.prog.BinWriter, opcode.DROP) emit.Opcode(c.prog.BinWriter, opcode.THROW) - } else if isStringType(c.typeInfo.Types[arg].Type) { + } else if isString(c.typeInfo.Types[arg].Type) { ast.Walk(c, arg) emit.Syscall(c.prog.BinWriter, "Neo.Runtime.Log") emit.Opcode(c.prog.BinWriter, opcode.THROW) @@ -1216,7 +1152,7 @@ func transformArgs(fun ast.Expr, args []ast.Expr) []ast.Expr { func (c *codegen) convertByteArray(lit *ast.CompositeLit) { buf := make([]byte, len(lit.Elts)) for i := 0; i < len(lit.Elts); i++ { - t := c.typeInfo.Types[lit.Elts[i]] + t := c.typeAndValueOf(lit.Elts[i]) val, _ := constant.Int64Val(t.Value) buf[i] = byte(val) } @@ -1237,7 +1173,7 @@ func (c *codegen) convertMap(lit *ast.CompositeLit) { func (c *codegen) convertStruct(lit *ast.CompositeLit) { // Create a new structScope to initialize and store // the positions of its variables. - strct, ok := c.typeInfo.TypeOf(lit).Underlying().(*types.Struct) + strct, ok := c.typeOf(lit).Underlying().(*types.Struct) if !ok { c.prog.Err = fmt.Errorf("the given literal is not of type struct: %v", lit) return diff --git a/pkg/compiler/debug.go b/pkg/compiler/debug.go index 29af3f425..9f174b5cc 100644 --- a/pkg/compiler/debug.go +++ b/pkg/compiler/debug.go @@ -191,7 +191,11 @@ func (c *codegen) scReturnTypeFromScope(scope *funcScope) string { } func (c *codegen) scTypeFromExpr(typ ast.Expr) string { - switch t := c.typeInfo.Types[typ].Type.(type) { + t := c.typeOf(typ) + if c.typeOf(typ) == nil { + return "Any" + } + switch t := t.Underlying().(type) { case *types.Basic: info := t.Info() switch { @@ -209,7 +213,7 @@ func (c *codegen) scTypeFromExpr(typ ast.Expr) string { case *types.Struct: return "Struct" case *types.Slice: - if isByteArrayType(t) { + if isByte(t.Elem()) { return "ByteArray" } return "Array" diff --git a/pkg/compiler/for_test.go b/pkg/compiler/for_test.go index 11d60c507..71160a41f 100644 --- a/pkg/compiler/for_test.go +++ b/pkg/compiler/for_test.go @@ -3,13 +3,9 @@ package compiler_test import ( "fmt" "math/big" - "strings" "testing" - "github.com/nspcc-dev/neo-go/pkg/compiler" - "github.com/nspcc-dev/neo-go/pkg/vm" - "github.com/stretchr/testify/require" ) func TestEntryPointWithMethod(t *testing.T) { @@ -707,20 +703,38 @@ func TestForLoopRangeNoVariable(t *testing.T) { eval(t, src, big.NewInt(3)) } -func TestForLoopRangeCompilerError(t *testing.T) { +func TestForLoopRangeValue(t *testing.T) { src := ` package foo - func f(a int) int { return 0 } + func f(a int) int { return a } func Main() int { - arr := []int{1, 2, 3} + var sum int + arr := []int{1, 9, 4} for _, v := range arr { - f(v) + sum += f(v) } - return 0 + return sum }` - _, err := compiler.Compile(strings.NewReader(src)) - require.Error(t, err) + eval(t, src, big.NewInt(14)) +} + +func TestForLoopRangeMap(t *testing.T) { + src := `package foo + func Main() int { + m := map[int]int{ + 1: 13, + 11: 17, + } + var sum int + for i, v := range m { + sum += i + sum += v + } + return sum + }` + + eval(t, src, big.NewInt(42)) } func TestForLoopComplexConditions(t *testing.T) { diff --git a/pkg/compiler/func_scope.go b/pkg/compiler/func_scope.go index d9e81916d..fde6ceadf 100644 --- a/pkg/compiler/func_scope.go +++ b/pkg/compiler/func_scope.go @@ -102,10 +102,14 @@ func (c *funcScope) countLocals() int { case *ast.ReturnStmt, *ast.IfStmt: size++ // This handles the inline GenDecl like "var x = 2" - case *ast.GenDecl: - switch t := n.Specs[0].(type) { - case *ast.ValueSpec: - if len(t.Values) > 0 { + case *ast.ValueSpec: + size += len(n.Names) + case *ast.RangeStmt: + if n.Tok == token.DEFINE { + if n.Key != nil { + size++ + } + if n.Value != nil { size++ } } diff --git a/pkg/compiler/global_test.go b/pkg/compiler/global_test.go index ede29b849..190d50fd5 100644 --- a/pkg/compiler/global_test.go +++ b/pkg/compiler/global_test.go @@ -19,3 +19,39 @@ func TestChangeGlobal(t *testing.T) { eval(t, src, big.NewInt(42)) } + +func TestMultiDeclaration(t *testing.T) { + src := `package foo + var a, b, c int + func Main() int { + a = 1 + b = 2 + c = 3 + return a + b + c + }` + eval(t, src, big.NewInt(6)) +} + +func TestMultiDeclarationLocal(t *testing.T) { + src := `package foo + func Main() int { + var a, b, c int + a = 1 + b = 2 + c = 3 + return a + b + c + }` + eval(t, src, big.NewInt(6)) +} + +func TestMultiDeclarationLocalCompound(t *testing.T) { + src := `package foo + func Main() int { + var a, b, c []int + a = append(a, 1) + b = append(b, 2) + c = append(c, 3) + return a[0] + b[0] + c[0] + }` + eval(t, src, big.NewInt(6)) +} diff --git a/pkg/compiler/slice_test.go b/pkg/compiler/slice_test.go index 40600ea28..629b1b52d 100644 --- a/pkg/compiler/slice_test.go +++ b/pkg/compiler/slice_test.go @@ -33,6 +33,16 @@ var sliceTestCases = []testCase{ `, big.NewInt(42), }, + { + "increase slice element with +=", + `package foo + func Main() int { + a := []int{1, 2, 3} + a[1] += 40 + return a[1] + }`, + big.NewInt(42), + }, { "complex test", ` @@ -130,6 +140,17 @@ var sliceTestCases = []testCase{ }`, []byte{2, 3}, }, + { + "declare byte slice", + `package foo + func Main() []byte { + var a []byte + a = append(a, 1) + a = append(a, 2) + return a + }`, + []byte{1, 2}, + }, { "declare compound slice", `package foo diff --git a/pkg/compiler/struct_test.go b/pkg/compiler/struct_test.go index aa0da998e..f50390899 100644 --- a/pkg/compiler/struct_test.go +++ b/pkg/compiler/struct_test.go @@ -134,6 +134,17 @@ var structTestCases = []testCase{ }`, big.NewInt(14), }, + { + "increase struct field with +=", + `package foo + type token struct { x int } + func Main() int { + t := token{x: 2} + t.x += 3 + return t.x + }`, + big.NewInt(5), + }, { "assign a struct field to a struct field", ` diff --git a/pkg/compiler/type_test.go b/pkg/compiler/type_test.go index c8db6fb0c..b188eeb19 100644 --- a/pkg/compiler/type_test.go +++ b/pkg/compiler/type_test.go @@ -1,6 +1,9 @@ package compiler_test -import "testing" +import ( + "math/big" + "testing" +) func TestCustomType(t *testing.T) { src := ` @@ -22,3 +25,15 @@ func TestCustomType(t *testing.T) { ` eval(t, src, []byte("some short string")) } + +func TestCustomTypeMethods(t *testing.T) { + src := `package foo + type bar int + func (b bar) add(a bar) bar { return a + b } + func Main() bar { + var b bar + b = 10 + return b.add(32) + }` + eval(t, src, big.NewInt(42)) +} diff --git a/pkg/compiler/types.go b/pkg/compiler/types.go new file mode 100644 index 000000000..2011220fc --- /dev/null +++ b/pkg/compiler/types.go @@ -0,0 +1,44 @@ +package compiler + +import ( + "go/ast" + "go/types" +) + +func (c *codegen) typeAndValueOf(e ast.Expr) types.TypeAndValue { + return c.typeInfo.Types[e] +} + +func (c *codegen) typeOf(e ast.Expr) types.Type { + return c.typeAndValueOf(e).Type +} + +func isBasicTypeOfKind(typ types.Type, ks ...types.BasicKind) bool { + if t, ok := typ.Underlying().(*types.Basic); ok { + k := t.Kind() + for i := range ks { + if k == ks[i] { + return true + } + } + } + return false +} + +func isByte(typ types.Type) bool { + return isBasicTypeOfKind(typ, types.Uint8, types.Int8) +} + +func isString(typ types.Type) bool { + return isBasicTypeOfKind(typ, types.String) +} + +func isCompoundSlice(typ types.Type) bool { + t, ok := typ.Underlying().(*types.Slice) + return ok && !isByte(t.Elem()) +} + +func isByteSlice(typ types.Type) bool { + t, ok := typ.Underlying().(*types.Slice) + return ok && isByte(t.Elem()) +}