diff --git a/pkg/compiler/analysis.go b/pkg/compiler/analysis.go index 880ce84fa..396b7d4c0 100644 --- a/pkg/compiler/analysis.go +++ b/pkg/compiler/analysis.go @@ -304,3 +304,11 @@ func canConvert(s string) bool { } return true } + +// canInline returns true if function is to be inlined. +// Currently there is a static list of function which are inlined, +// this may change in future. +func canInline(s string) bool { + return isNativeHelpersPath(s) || + strings.HasPrefix(s, "github.com/nspcc-dev/neo-go/pkg/compiler/testdata/inline") +} diff --git a/pkg/compiler/codegen.go b/pkg/compiler/codegen.go index ae98df43b..992b9c49b 100644 --- a/pkg/compiler/codegen.go +++ b/pkg/compiler/codegen.go @@ -31,6 +31,8 @@ type codegen struct { // Type information. typeInfo *types.Info + // pkgInfoInline is stack of type information for packages containing inline functions. + pkgInfoInline []*loader.PackageInfo // A mapping of func identifiers with their scope. funcs map[string]*funcScope @@ -406,6 +408,7 @@ func (c *codegen) convertFuncDecl(file ast.Node, decl *ast.FuncDecl, pkg *types. if sizeArg > 255 { c.prog.Err = errors.New("maximum of 255 local variables is allowed") } + sizeLoc = 255 // FIXME count locals including inline variables if sizeLoc != 0 || sizeArg != 0 { emit.Instruction(c.prog.BinWriter, opcode.INITSLOT, []byte{byte(sizeLoc), byte(sizeArg)}) } @@ -623,7 +626,9 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor { c.processDefers() c.saveSequencePoint(n) - emit.Opcodes(c.prog.BinWriter, opcode.RET) + if len(c.pkgInfoInline) == 0 { + emit.Opcodes(c.prog.BinWriter, opcode.RET) + } return nil case *ast.IfStmt: @@ -800,7 +805,12 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor { switch fun := n.Fun.(type) { case *ast.Ident: - f, ok = c.funcs[c.getIdentName("", fun.Name)] + var pkgName string + if len(c.pkgInfoInline) != 0 { + pkgName = c.pkgInfoInline[len(c.pkgInfoInline)-1].Pkg.Path() + } + f, ok = c.funcs[c.getIdentName(pkgName, fun.Name)] + isBuiltin = isGoBuiltin(fun.Name) if !ok && !isBuiltin { name = fun.Name @@ -809,6 +819,10 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor { if fun.Obj != nil && fun.Obj.Kind == ast.Var { isFunc = true } + if ok && canInline(f.pkg.Path()) { + c.inlineCall(f, n) + return nil + } case *ast.SelectorExpr: // If this is a method call we need to walk the AST to load the struct locally. // Otherwise this is a function call from a imported package and we can call it @@ -824,6 +838,10 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor { if ok { f.selector = fun.X.(*ast.Ident) isBuiltin = isCustomBuiltin(f) + if canInline(f.pkg.Path()) { + c.inlineCall(f, n) + return nil + } } else { typ := c.typeOf(fun) if _, ok := typ.(*types.Signature); ok { @@ -1919,7 +1937,7 @@ func (c *codegen) compile(info *buildInfo, pkg *loader.PackageInfo) error { // of bytecode space. name := c.getFuncNameFromDecl(pkg.Path(), n) if !isInitFunc(n) && !isDeployFunc(n) && funUsage.funcUsed(name) && - (!isInteropPath(pkg.Path()) || isNativeHelpersPath(pkg.Path())) { + (!isInteropPath(pkg.Path()) && !canInline(pkg.Path())) { c.convertFuncDecl(f, n, pkg) } } @@ -1970,7 +1988,8 @@ func (c *codegen) resolveFuncDecls(f *ast.File, pkg *types.Package) { for _, decl := range f.Decls { switch n := decl.(type) { case *ast.FuncDecl: - c.newFunc(n) + fs := c.newFunc(n) + fs.file = f } } } diff --git a/pkg/compiler/func_scope.go b/pkg/compiler/func_scope.go index f7923fe95..2734712bc 100644 --- a/pkg/compiler/func_scope.go +++ b/pkg/compiler/func_scope.go @@ -22,6 +22,8 @@ type funcScope struct { // Package where the function is defined. pkg *types.Package + file *ast.File + // Program label of the scope label uint16 diff --git a/pkg/compiler/inline.go b/pkg/compiler/inline.go new file mode 100644 index 000000000..9facbe792 --- /dev/null +++ b/pkg/compiler/inline.go @@ -0,0 +1,49 @@ +package compiler + +import ( + "go/ast" + "go/types" + + "github.com/nspcc-dev/neo-go/pkg/vm/emit" + "github.com/nspcc-dev/neo-go/pkg/vm/opcode" +) + +// inlineCall inlines call of n for function represented by f. +// Call `f(a,b)` for definition `func f(x,y int)` is translated to block: +// { +// x := a +// y := b +// +// } +func (c *codegen) inlineCall(f *funcScope, n *ast.CallExpr) { + pkg := c.buildInfo.program.Package(f.pkg.Path()) + sig := c.typeOf(n.Fun).(*types.Signature) + + // Arguments need to be walked with the current scope, + // while stored in the new. + oldScope := c.scope.vars.locals + c.scope.vars.newScope() + newScope := c.scope.vars.locals + defer c.scope.vars.dropScope() + for i := range n.Args { + c.scope.vars.locals = oldScope + ast.Walk(c, n.Args[i]) + c.scope.vars.locals = newScope + name := sig.Params().At(i).Name() + c.scope.newLocal(name) + c.emitStoreVar("", name) + } + + c.pkgInfoInline = append(c.pkgInfoInline, pkg) + oldMap := c.importMap + c.fillImportMap(f.file, pkg.Pkg) + ast.Inspect(f.decl, c.scope.analyzeVoidCalls) + ast.Walk(c, f.decl.Body) + if c.scope.voidCalls[n] { + for i := 0; i < f.decl.Type.Results.NumFields(); i++ { + emit.Opcodes(c.prog.BinWriter, opcode.DROP) + } + } + c.importMap = oldMap + c.pkgInfoInline = c.pkgInfoInline[:len(c.pkgInfoInline)-1] +} diff --git a/pkg/compiler/inline_test.go b/pkg/compiler/inline_test.go new file mode 100644 index 000000000..05c6db8b3 --- /dev/null +++ b/pkg/compiler/inline_test.go @@ -0,0 +1,125 @@ +package compiler_test + +import ( + "fmt" + "math/big" + "strings" + "testing" + + "github.com/nspcc-dev/neo-go/pkg/compiler" + "github.com/nspcc-dev/neo-go/pkg/vm/opcode" + "github.com/stretchr/testify/require" +) + +func checkCallCount(t *testing.T, src string, expectedCall, expectedInitSlot int) { + v := vmAndCompile(t, src) + ctx := v.Context() + actualCall := 0 + actualInitSlot := 0 + + for op, _, err := ctx.Next(); ; op, _, err = ctx.Next() { + require.NoError(t, err) + switch op { + case opcode.CALL, opcode.CALLL: + actualCall++ + case opcode.INITSLOT: + actualInitSlot++ + } + if ctx.IP() == ctx.LenInstr() { + break + } + } + require.Equal(t, expectedCall, actualCall) + require.Equal(t, expectedInitSlot, actualInitSlot) +} + +func TestInline(t *testing.T) { + srcTmpl := `package foo + import "github.com/nspcc-dev/neo-go/pkg/compiler/testdata/inline" + // local alias + func sum(a, b int) int { + return 42 + } + func Main() int { + %s + }` + t.Run("no return", func(t *testing.T) { + src := fmt.Sprintf(srcTmpl, `inline.NoArgsNoReturn() + return 1`) + checkCallCount(t, src, 0, 1) + eval(t, src, big.NewInt(1)) + }) + t.Run("has return, dropped", func(t *testing.T) { + src := fmt.Sprintf(srcTmpl, `inline.NoArgsReturn1() + return 2`) + checkCallCount(t, src, 0, 1) + eval(t, src, big.NewInt(2)) + }) + t.Run("drop twice", func(t *testing.T) { + src := fmt.Sprintf(srcTmpl, `inline.DropInsideInline() + return 42`) + checkCallCount(t, src, 0, 1) + eval(t, src, big.NewInt(42)) + }) + t.Run("no args return 1", func(t *testing.T) { + src := fmt.Sprintf(srcTmpl, `return inline.NoArgsReturn1()`) + checkCallCount(t, src, 0, 1) + eval(t, src, big.NewInt(1)) + }) + t.Run("sum", func(t *testing.T) { + src := fmt.Sprintf(srcTmpl, `return inline.Sum(1, 2)`) + checkCallCount(t, src, 0, 1) + eval(t, src, big.NewInt(3)) + }) + t.Run("sum squared (nested inline)", func(t *testing.T) { + src := fmt.Sprintf(srcTmpl, `return inline.SumSquared(1, 2)`) + checkCallCount(t, src, 0, 1) + eval(t, src, big.NewInt(9)) + }) + t.Run("inline function in inline function parameter", func(t *testing.T) { + src := fmt.Sprintf(srcTmpl, `return inline.Sum(inline.SumSquared(1, 2), inline.Sum(3, 4))`) + checkCallCount(t, src, 0, 1) + eval(t, src, big.NewInt(9+3+4)) + }) + t.Run("global name clash", func(t *testing.T) { + src := fmt.Sprintf(srcTmpl, `return inline.GetSumSameName()`) + checkCallCount(t, src, 0, 1) + eval(t, src, big.NewInt(42)) + }) + t.Run("local name clash", func(t *testing.T) { + src := fmt.Sprintf(srcTmpl, `return inline.Sum(inline.SumSquared(1, 2), sum(3, 4))`) + checkCallCount(t, src, 1, 2) + eval(t, src, big.NewInt(51)) + }) +} + +func TestInlineConversion(t *testing.T) { + src1 := `package foo + import "github.com/nspcc-dev/neo-go/pkg/compiler/testdata/inline" + var _ = inline.A + func Main() int { + a := 2 + return inline.SumSquared(1, a) + }` + b1, err := compiler.Compile("foo.go", strings.NewReader(src1)) + require.NoError(t, err) + + src2 := `package foo + import "github.com/nspcc-dev/neo-go/pkg/compiler/testdata/inline" + var _ = inline.A + func Main() int { + a := 2 + { + b := 1 + c := a + { + bb := b + cc := c + return (bb + cc) * (b + c) + } + } + }` + b2, err := compiler.Compile("foo.go", strings.NewReader(src2)) + require.NoError(t, err) + require.Equal(t, b2, b1) +} diff --git a/pkg/compiler/testdata/inline/a/a.go b/pkg/compiler/testdata/inline/a/a.go new file mode 100644 index 000000000..aa61e8142 --- /dev/null +++ b/pkg/compiler/testdata/inline/a/a.go @@ -0,0 +1,7 @@ +package a + +var A = 29 + +func GetA() int { + return A +} diff --git a/pkg/compiler/testdata/inline/b/b.go b/pkg/compiler/testdata/inline/b/b.go new file mode 100644 index 000000000..197ffb124 --- /dev/null +++ b/pkg/compiler/testdata/inline/b/b.go @@ -0,0 +1,7 @@ +package b + +var A = 12 + +func GetA() int { + return A +} diff --git a/pkg/compiler/testdata/inline/inline.go b/pkg/compiler/testdata/inline/inline.go new file mode 100644 index 000000000..d44ec461a --- /dev/null +++ b/pkg/compiler/testdata/inline/inline.go @@ -0,0 +1,32 @@ +package inline + +import ( + "github.com/nspcc-dev/neo-go/pkg/compiler/testdata/inline/a" + "github.com/nspcc-dev/neo-go/pkg/compiler/testdata/inline/b" +) + +func NoArgsNoReturn() {} +func NoArgsReturn1() int { + return 1 +} +func Sum(a, b int) int { + return a + b +} +func sum(x, y int) int { + return x + y +} +func SumSquared(a, b int) int { + return sum(a, b) * (a + b) +} + +var A = 1 + +func GetSumSameName() int { + return a.GetA() + b.GetA() + A +} + +func DropInsideInline() int { + sum(1, 2) + sum(3, 4) + return 7 +} diff --git a/pkg/compiler/types.go b/pkg/compiler/types.go index 649c1f146..f331b7eaf 100644 --- a/pkg/compiler/types.go +++ b/pkg/compiler/types.go @@ -8,6 +8,11 @@ import ( ) func (c *codegen) typeAndValueOf(e ast.Expr) types.TypeAndValue { + for i := len(c.pkgInfoInline) - 1; i >= 0; i-- { + if tv, ok := c.pkgInfoInline[i].Types[e]; ok { + return tv + } + } return c.typeInfo.Types[e] }