compiler: support basic inlining

This commit is contained in:
Evgeniy Stratonikov 2021-02-04 15:41:00 +03:00
parent 1f238ce6fd
commit 1ae0d022dd
9 changed files with 258 additions and 4 deletions

View file

@ -304,3 +304,11 @@ func canConvert(s string) bool {
} }
return true return true
} }
// canInline returns true if function is to be inlined.
// Currently there is a static list of function which are inlined,
// this may change in future.
func canInline(s string) bool {
return isNativeHelpersPath(s) ||
strings.HasPrefix(s, "github.com/nspcc-dev/neo-go/pkg/compiler/testdata/inline")
}

View file

@ -31,6 +31,8 @@ type codegen struct {
// Type information. // Type information.
typeInfo *types.Info typeInfo *types.Info
// pkgInfoInline is stack of type information for packages containing inline functions.
pkgInfoInline []*loader.PackageInfo
// A mapping of func identifiers with their scope. // A mapping of func identifiers with their scope.
funcs map[string]*funcScope funcs map[string]*funcScope
@ -406,6 +408,7 @@ func (c *codegen) convertFuncDecl(file ast.Node, decl *ast.FuncDecl, pkg *types.
if sizeArg > 255 { if sizeArg > 255 {
c.prog.Err = errors.New("maximum of 255 local variables is allowed") c.prog.Err = errors.New("maximum of 255 local variables is allowed")
} }
sizeLoc = 255 // FIXME count locals including inline variables
if sizeLoc != 0 || sizeArg != 0 { if sizeLoc != 0 || sizeArg != 0 {
emit.Instruction(c.prog.BinWriter, opcode.INITSLOT, []byte{byte(sizeLoc), byte(sizeArg)}) emit.Instruction(c.prog.BinWriter, opcode.INITSLOT, []byte{byte(sizeLoc), byte(sizeArg)})
} }
@ -623,7 +626,9 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
c.processDefers() c.processDefers()
c.saveSequencePoint(n) c.saveSequencePoint(n)
emit.Opcodes(c.prog.BinWriter, opcode.RET) if len(c.pkgInfoInline) == 0 {
emit.Opcodes(c.prog.BinWriter, opcode.RET)
}
return nil return nil
case *ast.IfStmt: case *ast.IfStmt:
@ -800,7 +805,12 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
switch fun := n.Fun.(type) { switch fun := n.Fun.(type) {
case *ast.Ident: case *ast.Ident:
f, ok = c.funcs[c.getIdentName("", fun.Name)] var pkgName string
if len(c.pkgInfoInline) != 0 {
pkgName = c.pkgInfoInline[len(c.pkgInfoInline)-1].Pkg.Path()
}
f, ok = c.funcs[c.getIdentName(pkgName, fun.Name)]
isBuiltin = isGoBuiltin(fun.Name) isBuiltin = isGoBuiltin(fun.Name)
if !ok && !isBuiltin { if !ok && !isBuiltin {
name = fun.Name name = fun.Name
@ -809,6 +819,10 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
if fun.Obj != nil && fun.Obj.Kind == ast.Var { if fun.Obj != nil && fun.Obj.Kind == ast.Var {
isFunc = true isFunc = true
} }
if ok && canInline(f.pkg.Path()) {
c.inlineCall(f, n)
return nil
}
case *ast.SelectorExpr: case *ast.SelectorExpr:
// If this is a method call we need to walk the AST to load the struct locally. // If this is a method call we need to walk the AST to load the struct locally.
// Otherwise this is a function call from a imported package and we can call it // Otherwise this is a function call from a imported package and we can call it
@ -824,6 +838,10 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
if ok { if ok {
f.selector = fun.X.(*ast.Ident) f.selector = fun.X.(*ast.Ident)
isBuiltin = isCustomBuiltin(f) isBuiltin = isCustomBuiltin(f)
if canInline(f.pkg.Path()) {
c.inlineCall(f, n)
return nil
}
} else { } else {
typ := c.typeOf(fun) typ := c.typeOf(fun)
if _, ok := typ.(*types.Signature); ok { if _, ok := typ.(*types.Signature); ok {
@ -1919,7 +1937,7 @@ func (c *codegen) compile(info *buildInfo, pkg *loader.PackageInfo) error {
// of bytecode space. // of bytecode space.
name := c.getFuncNameFromDecl(pkg.Path(), n) name := c.getFuncNameFromDecl(pkg.Path(), n)
if !isInitFunc(n) && !isDeployFunc(n) && funUsage.funcUsed(name) && if !isInitFunc(n) && !isDeployFunc(n) && funUsage.funcUsed(name) &&
(!isInteropPath(pkg.Path()) || isNativeHelpersPath(pkg.Path())) { (!isInteropPath(pkg.Path()) && !canInline(pkg.Path())) {
c.convertFuncDecl(f, n, pkg) c.convertFuncDecl(f, n, pkg)
} }
} }
@ -1970,7 +1988,8 @@ func (c *codegen) resolveFuncDecls(f *ast.File, pkg *types.Package) {
for _, decl := range f.Decls { for _, decl := range f.Decls {
switch n := decl.(type) { switch n := decl.(type) {
case *ast.FuncDecl: case *ast.FuncDecl:
c.newFunc(n) fs := c.newFunc(n)
fs.file = f
} }
} }
} }

View file

@ -22,6 +22,8 @@ type funcScope struct {
// Package where the function is defined. // Package where the function is defined.
pkg *types.Package pkg *types.Package
file *ast.File
// Program label of the scope // Program label of the scope
label uint16 label uint16

49
pkg/compiler/inline.go Normal file
View file

@ -0,0 +1,49 @@
package compiler
import (
"go/ast"
"go/types"
"github.com/nspcc-dev/neo-go/pkg/vm/emit"
"github.com/nspcc-dev/neo-go/pkg/vm/opcode"
)
// inlineCall inlines call of n for function represented by f.
// Call `f(a,b)` for definition `func f(x,y int)` is translated to block:
// {
// x := a
// y := b
// <inline body of f directly>
// }
func (c *codegen) inlineCall(f *funcScope, n *ast.CallExpr) {
pkg := c.buildInfo.program.Package(f.pkg.Path())
sig := c.typeOf(n.Fun).(*types.Signature)
// Arguments need to be walked with the current scope,
// while stored in the new.
oldScope := c.scope.vars.locals
c.scope.vars.newScope()
newScope := c.scope.vars.locals
defer c.scope.vars.dropScope()
for i := range n.Args {
c.scope.vars.locals = oldScope
ast.Walk(c, n.Args[i])
c.scope.vars.locals = newScope
name := sig.Params().At(i).Name()
c.scope.newLocal(name)
c.emitStoreVar("", name)
}
c.pkgInfoInline = append(c.pkgInfoInline, pkg)
oldMap := c.importMap
c.fillImportMap(f.file, pkg.Pkg)
ast.Inspect(f.decl, c.scope.analyzeVoidCalls)
ast.Walk(c, f.decl.Body)
if c.scope.voidCalls[n] {
for i := 0; i < f.decl.Type.Results.NumFields(); i++ {
emit.Opcodes(c.prog.BinWriter, opcode.DROP)
}
}
c.importMap = oldMap
c.pkgInfoInline = c.pkgInfoInline[:len(c.pkgInfoInline)-1]
}

125
pkg/compiler/inline_test.go Normal file
View file

@ -0,0 +1,125 @@
package compiler_test
import (
"fmt"
"math/big"
"strings"
"testing"
"github.com/nspcc-dev/neo-go/pkg/compiler"
"github.com/nspcc-dev/neo-go/pkg/vm/opcode"
"github.com/stretchr/testify/require"
)
func checkCallCount(t *testing.T, src string, expectedCall, expectedInitSlot int) {
v := vmAndCompile(t, src)
ctx := v.Context()
actualCall := 0
actualInitSlot := 0
for op, _, err := ctx.Next(); ; op, _, err = ctx.Next() {
require.NoError(t, err)
switch op {
case opcode.CALL, opcode.CALLL:
actualCall++
case opcode.INITSLOT:
actualInitSlot++
}
if ctx.IP() == ctx.LenInstr() {
break
}
}
require.Equal(t, expectedCall, actualCall)
require.Equal(t, expectedInitSlot, actualInitSlot)
}
func TestInline(t *testing.T) {
srcTmpl := `package foo
import "github.com/nspcc-dev/neo-go/pkg/compiler/testdata/inline"
// local alias
func sum(a, b int) int {
return 42
}
func Main() int {
%s
}`
t.Run("no return", func(t *testing.T) {
src := fmt.Sprintf(srcTmpl, `inline.NoArgsNoReturn()
return 1`)
checkCallCount(t, src, 0, 1)
eval(t, src, big.NewInt(1))
})
t.Run("has return, dropped", func(t *testing.T) {
src := fmt.Sprintf(srcTmpl, `inline.NoArgsReturn1()
return 2`)
checkCallCount(t, src, 0, 1)
eval(t, src, big.NewInt(2))
})
t.Run("drop twice", func(t *testing.T) {
src := fmt.Sprintf(srcTmpl, `inline.DropInsideInline()
return 42`)
checkCallCount(t, src, 0, 1)
eval(t, src, big.NewInt(42))
})
t.Run("no args return 1", func(t *testing.T) {
src := fmt.Sprintf(srcTmpl, `return inline.NoArgsReturn1()`)
checkCallCount(t, src, 0, 1)
eval(t, src, big.NewInt(1))
})
t.Run("sum", func(t *testing.T) {
src := fmt.Sprintf(srcTmpl, `return inline.Sum(1, 2)`)
checkCallCount(t, src, 0, 1)
eval(t, src, big.NewInt(3))
})
t.Run("sum squared (nested inline)", func(t *testing.T) {
src := fmt.Sprintf(srcTmpl, `return inline.SumSquared(1, 2)`)
checkCallCount(t, src, 0, 1)
eval(t, src, big.NewInt(9))
})
t.Run("inline function in inline function parameter", func(t *testing.T) {
src := fmt.Sprintf(srcTmpl, `return inline.Sum(inline.SumSquared(1, 2), inline.Sum(3, 4))`)
checkCallCount(t, src, 0, 1)
eval(t, src, big.NewInt(9+3+4))
})
t.Run("global name clash", func(t *testing.T) {
src := fmt.Sprintf(srcTmpl, `return inline.GetSumSameName()`)
checkCallCount(t, src, 0, 1)
eval(t, src, big.NewInt(42))
})
t.Run("local name clash", func(t *testing.T) {
src := fmt.Sprintf(srcTmpl, `return inline.Sum(inline.SumSquared(1, 2), sum(3, 4))`)
checkCallCount(t, src, 1, 2)
eval(t, src, big.NewInt(51))
})
}
func TestInlineConversion(t *testing.T) {
src1 := `package foo
import "github.com/nspcc-dev/neo-go/pkg/compiler/testdata/inline"
var _ = inline.A
func Main() int {
a := 2
return inline.SumSquared(1, a)
}`
b1, err := compiler.Compile("foo.go", strings.NewReader(src1))
require.NoError(t, err)
src2 := `package foo
import "github.com/nspcc-dev/neo-go/pkg/compiler/testdata/inline"
var _ = inline.A
func Main() int {
a := 2
{
b := 1
c := a
{
bb := b
cc := c
return (bb + cc) * (b + c)
}
}
}`
b2, err := compiler.Compile("foo.go", strings.NewReader(src2))
require.NoError(t, err)
require.Equal(t, b2, b1)
}

7
pkg/compiler/testdata/inline/a/a.go vendored Normal file
View file

@ -0,0 +1,7 @@
package a
var A = 29
func GetA() int {
return A
}

7
pkg/compiler/testdata/inline/b/b.go vendored Normal file
View file

@ -0,0 +1,7 @@
package b
var A = 12
func GetA() int {
return A
}

32
pkg/compiler/testdata/inline/inline.go vendored Normal file
View file

@ -0,0 +1,32 @@
package inline
import (
"github.com/nspcc-dev/neo-go/pkg/compiler/testdata/inline/a"
"github.com/nspcc-dev/neo-go/pkg/compiler/testdata/inline/b"
)
func NoArgsNoReturn() {}
func NoArgsReturn1() int {
return 1
}
func Sum(a, b int) int {
return a + b
}
func sum(x, y int) int {
return x + y
}
func SumSquared(a, b int) int {
return sum(a, b) * (a + b)
}
var A = 1
func GetSumSameName() int {
return a.GetA() + b.GetA() + A
}
func DropInsideInline() int {
sum(1, 2)
sum(3, 4)
return 7
}

View file

@ -8,6 +8,11 @@ import (
) )
func (c *codegen) typeAndValueOf(e ast.Expr) types.TypeAndValue { func (c *codegen) typeAndValueOf(e ast.Expr) types.TypeAndValue {
for i := len(c.pkgInfoInline) - 1; i >= 0; i-- {
if tv, ok := c.pkgInfoInline[i].Types[e]; ok {
return tv
}
}
return c.typeInfo.Types[e] return c.typeInfo.Types[e]
} }