From 70d0ff869d4f5a1f797fabc2cb9cd009bbb6a424 Mon Sep 17 00:00:00 2001 From: Evgenii Stratonikov Date: Wed, 13 May 2020 18:09:55 +0300 Subject: [PATCH] compiler: refactor typeinfo functions --- pkg/compiler/analysis.go | 52 -------------------- pkg/compiler/codegen.go | 103 ++++++++++++++++----------------------- pkg/compiler/debug.go | 8 ++- pkg/compiler/types.go | 44 +++++++++++++++++ 4 files changed, 93 insertions(+), 114 deletions(-) create mode 100644 pkg/compiler/types.go diff --git a/pkg/compiler/analysis.go b/pkg/compiler/analysis.go index c2382cfd9..b79d1de85 100644 --- a/pkg/compiler/analysis.go +++ b/pkg/compiler/analysis.go @@ -205,50 +205,6 @@ func isBuiltin(expr ast.Expr) bool { return false } -func (c *codegen) isCompoundArrayType(t ast.Expr) bool { - switch s := t.(type) { - case *ast.ArrayType: - return true - case *ast.Ident: - arr, ok := c.typeInfo.Types[s].Type.Underlying().(*types.Slice) - return ok && !isByte(arr.Elem()) - } - return false -} - -func isByte(t types.Type) bool { - e, ok := t.(*types.Basic) - return ok && e.Kind() == types.Byte -} - -func (c *codegen) isStructType(t ast.Expr) (int, bool) { - switch s := t.(type) { - case *ast.StructType: - return s.Fields.NumFields(), true - case *ast.Ident: - st, ok := c.typeInfo.Types[s].Type.Underlying().(*types.Struct) - if ok { - return st.NumFields(), true - } - } - return 0, false -} - -func isByteArray(lit *ast.CompositeLit, tInfo *types.Info) bool { - if len(lit.Elts) == 0 { - if typ, ok := lit.Type.(*ast.ArrayType); ok { - if name, ok := typ.Elt.(*ast.Ident); ok { - return name.Name == "byte" || name.Name == "uint8" - } - } - - return false - } - - typ := tInfo.Types[lit.Elts[0]].Type.Underlying() - return isByte(typ) -} - func isSyscall(fun *funcScope) bool { if fun.selector == nil { return false @@ -256,11 +212,3 @@ func isSyscall(fun *funcScope) bool { _, ok := syscalls[fun.selector.Name][fun.name] return ok } - -func isByteArrayType(t types.Type) bool { - return t.String() == "[]byte" -} - -func isStringType(t types.Type) bool { - return t.String() == "string" -} diff --git a/pkg/compiler/codegen.go b/pkg/compiler/codegen.go index ca1766d97..54ca11b4d 100644 --- a/pkg/compiler/codegen.go +++ b/pkg/compiler/codegen.go @@ -125,16 +125,13 @@ func (c *codegen) emitLoadConst(t types.TypeAndValue) { if c.prog.Err != nil { return } - switch typ := t.Type.Underlying().(type) { - case *types.Basic: - c.convertBasicType(t, typ) - default: + + typ, ok := t.Type.Underlying().(*types.Basic) + if !ok { c.prog.Err = fmt.Errorf("compiler doesn't know how to convert this constant: %v", t) return } -} -func (c *codegen) convertBasicType(t types.TypeAndValue, typ *types.Basic) { switch typ.Kind() { case types.Int, types.UntypedInt, types.Uint, types.Int16, types.Uint16, @@ -376,18 +373,21 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor { ast.Walk(c, val) c.emitStoreVar(t.Names[i].Name) } - } else if c.isCompoundArrayType(t.Type) { - emit.Opcode(c.prog.BinWriter, opcode.PUSH0) - emit.Opcode(c.prog.BinWriter, opcode.NEWARRAY) - c.emitStoreVar(t.Names[0].Name) - } else if n, ok := c.isStructType(t.Type); ok { - emit.Int(c.prog.BinWriter, int64(n)) - emit.Opcode(c.prog.BinWriter, opcode.NEWSTRUCT) - c.emitStoreVar(t.Names[0].Name) } else { - for _, id := range t.Names { - c.emitDefault(t.Type) - c.emitStoreVar(id.Name) + typ := c.typeOf(t.Type) + if isCompoundSlice(typ) { + emit.Opcode(c.prog.BinWriter, opcode.PUSH0) + emit.Opcode(c.prog.BinWriter, opcode.NEWARRAY) + c.emitStoreVar(t.Names[0].Name) + } else if s, ok := typ.Underlying().(*types.Struct); ok { + emit.Int(c.prog.BinWriter, int64(s.NumFields())) + emit.Opcode(c.prog.BinWriter, opcode.NEWSTRUCT) + c.emitStoreVar(t.Names[0].Name) + } else { + for _, id := range t.Names { + c.emitDefault(t.Type) + c.emitStoreVar(id.Name) + } } } } @@ -430,8 +430,7 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor { switch expr := t.X.(type) { case *ast.Ident: ast.Walk(c, n.Rhs[i]) - typ := c.typeInfo.ObjectOf(expr).Type().Underlying() - if strct, ok := typ.(*types.Struct); ok { + if strct, ok := c.typeOf(expr).Underlying().(*types.Struct); ok { c.emitLoadVar(expr.Name) // load the struct i := indexOfStruct(strct, t.Sel.Name) // get the index of the field c.emitStoreStructField(i) // store the field @@ -607,7 +606,7 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor { return nil case *ast.BasicLit: - c.emitLoadConst(c.typeInfo.Types[n]) + c.emitLoadConst(c.typeAndValueOf(n)) return nil case *ast.Ident: @@ -618,7 +617,7 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor { return nil } c.emitLoadConst(value) - } else if tv := c.typeInfo.Types[n]; tv.Value != nil { + } else if tv := c.typeAndValueOf(n); tv.Value != nil { c.emitLoadConst(tv) } else { c.emitLoadVar(n.Name) @@ -626,19 +625,19 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor { return nil case *ast.CompositeLit: - var typ types.Type - - switch t := n.Type.(type) { - case *ast.Ident: - typ = c.typeInfo.ObjectOf(t).Type().Underlying() - case *ast.SelectorExpr: - typ = c.typeInfo.ObjectOf(t.Sel).Type().Underlying() - case *ast.MapType: - typ = c.typeInfo.TypeOf(t) + typ := c.typeOf(n.Type).Underlying() + switch n.Type.(type) { + case *ast.Ident, *ast.SelectorExpr, *ast.MapType: + switch typ.(type) { + case *types.Struct: + c.convertStruct(n) + case *types.Map: + c.convertMap(n) + } default: ln := len(n.Elts) // ByteArrays needs a different approach than normal arrays. - if isByteArray(n, c.typeInfo) { + if isByteSlice(typ) { c.convertByteArray(n) return nil } @@ -647,14 +646,6 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor { } emit.Int(c.prog.BinWriter, int64(ln)) emit.Opcode(c.prog.BinWriter, opcode.PACK) - return nil - } - - switch typ.(type) { - case *types.Struct: - c.convertStruct(n) - case *types.Map: - c.convertMap(n) } return nil @@ -693,7 +684,7 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor { // example: // const x = 10 // x + 2 will results into 12 - tinfo := c.typeInfo.Types[n] + tinfo := c.typeAndValueOf(n) if tinfo.Value != nil { c.emitLoadConst(tinfo) return nil @@ -705,7 +696,7 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor { switch { case n.Op == token.ADD: // VM has separate opcodes for number and string concatenation - if isStringType(tinfo.Type) { + if isString(tinfo.Type) { emit.Opcode(c.prog.BinWriter, opcode.CAT) } else { emit.Opcode(c.prog.BinWriter, opcode.ADD) @@ -716,7 +707,7 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor { emit.Opcode(c.prog.BinWriter, op) case n.Op == token.NEQ: // VM has separate opcodes for number and string equality - if isStringType(c.typeInfo.Types[n.X].Type) { + if isString(c.typeOf(n.X)) { emit.Opcode(c.prog.BinWriter, opcode.NOTEQUAL) } else { emit.Opcode(c.prog.BinWriter, opcode.NUMNOTEQUAL) @@ -800,8 +791,7 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor { case *ast.SelectorExpr: switch t := n.X.(type) { case *ast.Ident: - typ := c.typeInfo.ObjectOf(t).Type().Underlying() - if strct, ok := typ.(*types.Struct); ok { + if strct, ok := c.typeOf(t).Underlying().(*types.Struct); ok { c.emitLoadVar(t.Name) // load the struct i := indexOfStruct(strct, n.Sel.Name) c.emitLoadField(i) // load the field @@ -852,14 +842,7 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor { switch n.Index.(type) { case *ast.BasicLit: - t := c.typeInfo.Types[n.Index] - switch typ := t.Type.Underlying().(type) { - case *types.Basic: - c.convertBasicType(t, typ) - default: - c.prog.Err = fmt.Errorf("compiler can't use following type as an index: %T", typ) - return nil - } + c.emitLoadConst(c.typeAndValueOf(n.Index)) default: ast.Walk(c, n.Index) } @@ -1067,7 +1050,7 @@ func (c *codegen) getLabelOffset(typ labelOffsetType, name string) uint16 { } func (c *codegen) getEqualityOpcode(expr ast.Expr) opcode.Opcode { - t, ok := c.typeInfo.Types[expr].Type.Underlying().(*types.Basic) + t, ok := c.typeOf(expr).Underlying().(*types.Basic) if ok && t.Info()&types.IsNumeric != 0 { return opcode.NUMEQUAL } @@ -1080,18 +1063,18 @@ func (c *codegen) getEqualityOpcode(expr ast.Expr) opcode.Opcode { func (c *codegen) getByteArray(expr ast.Expr) []byte { switch t := expr.(type) { case *ast.CompositeLit: - if !isByteArray(t, c.typeInfo) { + if !isByteSlice(c.typeOf(t.Type)) { return nil } buf := make([]byte, len(t.Elts)) for i := 0; i < len(t.Elts); i++ { - t := c.typeInfo.Types[t.Elts[i]] + t := c.typeAndValueOf(t.Elts[i]) val, _ := constant.Int64Val(t.Value) buf[i] = byte(val) } return buf case *ast.CallExpr: - if tv := c.typeInfo.Types[t.Args[0]]; tv.Value != nil { + if tv := c.typeAndValueOf(t.Args[0]); tv.Value != nil { val := constant.StringVal(tv.Value) return []byte(val) } @@ -1136,7 +1119,7 @@ func (c *codegen) convertBuiltin(expr *ast.CallExpr) { case "append": arg := expr.Args[0] typ := c.typeInfo.Types[arg].Type - if isByteArrayType(typ) { + if isByteSlice(typ) { emit.Opcode(c.prog.BinWriter, opcode.CAT) } else { emit.Opcode(c.prog.BinWriter, opcode.OVER) @@ -1148,7 +1131,7 @@ func (c *codegen) convertBuiltin(expr *ast.CallExpr) { if isExprNil(arg) { emit.Opcode(c.prog.BinWriter, opcode.DROP) emit.Opcode(c.prog.BinWriter, opcode.THROW) - } else if isStringType(c.typeInfo.Types[arg].Type) { + } else if isString(c.typeInfo.Types[arg].Type) { ast.Walk(c, arg) emit.Syscall(c.prog.BinWriter, "Neo.Runtime.Log") emit.Opcode(c.prog.BinWriter, opcode.THROW) @@ -1216,7 +1199,7 @@ func transformArgs(fun ast.Expr, args []ast.Expr) []ast.Expr { func (c *codegen) convertByteArray(lit *ast.CompositeLit) { buf := make([]byte, len(lit.Elts)) for i := 0; i < len(lit.Elts); i++ { - t := c.typeInfo.Types[lit.Elts[i]] + t := c.typeAndValueOf(lit.Elts[i]) val, _ := constant.Int64Val(t.Value) buf[i] = byte(val) } @@ -1237,7 +1220,7 @@ func (c *codegen) convertMap(lit *ast.CompositeLit) { func (c *codegen) convertStruct(lit *ast.CompositeLit) { // Create a new structScope to initialize and store // the positions of its variables. - strct, ok := c.typeInfo.TypeOf(lit).Underlying().(*types.Struct) + strct, ok := c.typeOf(lit).Underlying().(*types.Struct) if !ok { c.prog.Err = fmt.Errorf("the given literal is not of type struct: %v", lit) return diff --git a/pkg/compiler/debug.go b/pkg/compiler/debug.go index 29af3f425..9f174b5cc 100644 --- a/pkg/compiler/debug.go +++ b/pkg/compiler/debug.go @@ -191,7 +191,11 @@ func (c *codegen) scReturnTypeFromScope(scope *funcScope) string { } func (c *codegen) scTypeFromExpr(typ ast.Expr) string { - switch t := c.typeInfo.Types[typ].Type.(type) { + t := c.typeOf(typ) + if c.typeOf(typ) == nil { + return "Any" + } + switch t := t.Underlying().(type) { case *types.Basic: info := t.Info() switch { @@ -209,7 +213,7 @@ func (c *codegen) scTypeFromExpr(typ ast.Expr) string { case *types.Struct: return "Struct" case *types.Slice: - if isByteArrayType(t) { + if isByte(t.Elem()) { return "ByteArray" } return "Array" diff --git a/pkg/compiler/types.go b/pkg/compiler/types.go new file mode 100644 index 000000000..2011220fc --- /dev/null +++ b/pkg/compiler/types.go @@ -0,0 +1,44 @@ +package compiler + +import ( + "go/ast" + "go/types" +) + +func (c *codegen) typeAndValueOf(e ast.Expr) types.TypeAndValue { + return c.typeInfo.Types[e] +} + +func (c *codegen) typeOf(e ast.Expr) types.Type { + return c.typeAndValueOf(e).Type +} + +func isBasicTypeOfKind(typ types.Type, ks ...types.BasicKind) bool { + if t, ok := typ.Underlying().(*types.Basic); ok { + k := t.Kind() + for i := range ks { + if k == ks[i] { + return true + } + } + } + return false +} + +func isByte(typ types.Type) bool { + return isBasicTypeOfKind(typ, types.Uint8, types.Int8) +} + +func isString(typ types.Type) bool { + return isBasicTypeOfKind(typ, types.String) +} + +func isCompoundSlice(typ types.Type) bool { + t, ok := typ.Underlying().(*types.Slice) + return ok && !isByte(t.Elem()) +} + +func isByteSlice(typ types.Type) bool { + t, ok := typ.Underlying().(*types.Slice) + return ok && isByte(t.Elem()) +}