compiler: support basic inlining
This commit is contained in:
parent
1f238ce6fd
commit
1ae0d022dd
9 changed files with 258 additions and 4 deletions
|
@ -304,3 +304,11 @@ func canConvert(s string) bool {
|
|||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// canInline returns true if function is to be inlined.
|
||||
// Currently there is a static list of function which are inlined,
|
||||
// this may change in future.
|
||||
func canInline(s string) bool {
|
||||
return isNativeHelpersPath(s) ||
|
||||
strings.HasPrefix(s, "github.com/nspcc-dev/neo-go/pkg/compiler/testdata/inline")
|
||||
}
|
||||
|
|
|
@ -31,6 +31,8 @@ type codegen struct {
|
|||
|
||||
// Type information.
|
||||
typeInfo *types.Info
|
||||
// pkgInfoInline is stack of type information for packages containing inline functions.
|
||||
pkgInfoInline []*loader.PackageInfo
|
||||
|
||||
// A mapping of func identifiers with their scope.
|
||||
funcs map[string]*funcScope
|
||||
|
@ -406,6 +408,7 @@ func (c *codegen) convertFuncDecl(file ast.Node, decl *ast.FuncDecl, pkg *types.
|
|||
if sizeArg > 255 {
|
||||
c.prog.Err = errors.New("maximum of 255 local variables is allowed")
|
||||
}
|
||||
sizeLoc = 255 // FIXME count locals including inline variables
|
||||
if sizeLoc != 0 || sizeArg != 0 {
|
||||
emit.Instruction(c.prog.BinWriter, opcode.INITSLOT, []byte{byte(sizeLoc), byte(sizeArg)})
|
||||
}
|
||||
|
@ -623,7 +626,9 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
|
|||
c.processDefers()
|
||||
|
||||
c.saveSequencePoint(n)
|
||||
emit.Opcodes(c.prog.BinWriter, opcode.RET)
|
||||
if len(c.pkgInfoInline) == 0 {
|
||||
emit.Opcodes(c.prog.BinWriter, opcode.RET)
|
||||
}
|
||||
return nil
|
||||
|
||||
case *ast.IfStmt:
|
||||
|
@ -800,7 +805,12 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
|
|||
|
||||
switch fun := n.Fun.(type) {
|
||||
case *ast.Ident:
|
||||
f, ok = c.funcs[c.getIdentName("", fun.Name)]
|
||||
var pkgName string
|
||||
if len(c.pkgInfoInline) != 0 {
|
||||
pkgName = c.pkgInfoInline[len(c.pkgInfoInline)-1].Pkg.Path()
|
||||
}
|
||||
f, ok = c.funcs[c.getIdentName(pkgName, fun.Name)]
|
||||
|
||||
isBuiltin = isGoBuiltin(fun.Name)
|
||||
if !ok && !isBuiltin {
|
||||
name = fun.Name
|
||||
|
@ -809,6 +819,10 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
|
|||
if fun.Obj != nil && fun.Obj.Kind == ast.Var {
|
||||
isFunc = true
|
||||
}
|
||||
if ok && canInline(f.pkg.Path()) {
|
||||
c.inlineCall(f, n)
|
||||
return nil
|
||||
}
|
||||
case *ast.SelectorExpr:
|
||||
// If this is a method call we need to walk the AST to load the struct locally.
|
||||
// Otherwise this is a function call from a imported package and we can call it
|
||||
|
@ -824,6 +838,10 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
|
|||
if ok {
|
||||
f.selector = fun.X.(*ast.Ident)
|
||||
isBuiltin = isCustomBuiltin(f)
|
||||
if canInline(f.pkg.Path()) {
|
||||
c.inlineCall(f, n)
|
||||
return nil
|
||||
}
|
||||
} else {
|
||||
typ := c.typeOf(fun)
|
||||
if _, ok := typ.(*types.Signature); ok {
|
||||
|
@ -1919,7 +1937,7 @@ func (c *codegen) compile(info *buildInfo, pkg *loader.PackageInfo) error {
|
|||
// of bytecode space.
|
||||
name := c.getFuncNameFromDecl(pkg.Path(), n)
|
||||
if !isInitFunc(n) && !isDeployFunc(n) && funUsage.funcUsed(name) &&
|
||||
(!isInteropPath(pkg.Path()) || isNativeHelpersPath(pkg.Path())) {
|
||||
(!isInteropPath(pkg.Path()) && !canInline(pkg.Path())) {
|
||||
c.convertFuncDecl(f, n, pkg)
|
||||
}
|
||||
}
|
||||
|
@ -1970,7 +1988,8 @@ func (c *codegen) resolveFuncDecls(f *ast.File, pkg *types.Package) {
|
|||
for _, decl := range f.Decls {
|
||||
switch n := decl.(type) {
|
||||
case *ast.FuncDecl:
|
||||
c.newFunc(n)
|
||||
fs := c.newFunc(n)
|
||||
fs.file = f
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -22,6 +22,8 @@ type funcScope struct {
|
|||
// Package where the function is defined.
|
||||
pkg *types.Package
|
||||
|
||||
file *ast.File
|
||||
|
||||
// Program label of the scope
|
||||
label uint16
|
||||
|
||||
|
|
49
pkg/compiler/inline.go
Normal file
49
pkg/compiler/inline.go
Normal file
|
@ -0,0 +1,49 @@
|
|||
package compiler
|
||||
|
||||
import (
|
||||
"go/ast"
|
||||
"go/types"
|
||||
|
||||
"github.com/nspcc-dev/neo-go/pkg/vm/emit"
|
||||
"github.com/nspcc-dev/neo-go/pkg/vm/opcode"
|
||||
)
|
||||
|
||||
// inlineCall inlines call of n for function represented by f.
|
||||
// Call `f(a,b)` for definition `func f(x,y int)` is translated to block:
|
||||
// {
|
||||
// x := a
|
||||
// y := b
|
||||
// <inline body of f directly>
|
||||
// }
|
||||
func (c *codegen) inlineCall(f *funcScope, n *ast.CallExpr) {
|
||||
pkg := c.buildInfo.program.Package(f.pkg.Path())
|
||||
sig := c.typeOf(n.Fun).(*types.Signature)
|
||||
|
||||
// Arguments need to be walked with the current scope,
|
||||
// while stored in the new.
|
||||
oldScope := c.scope.vars.locals
|
||||
c.scope.vars.newScope()
|
||||
newScope := c.scope.vars.locals
|
||||
defer c.scope.vars.dropScope()
|
||||
for i := range n.Args {
|
||||
c.scope.vars.locals = oldScope
|
||||
ast.Walk(c, n.Args[i])
|
||||
c.scope.vars.locals = newScope
|
||||
name := sig.Params().At(i).Name()
|
||||
c.scope.newLocal(name)
|
||||
c.emitStoreVar("", name)
|
||||
}
|
||||
|
||||
c.pkgInfoInline = append(c.pkgInfoInline, pkg)
|
||||
oldMap := c.importMap
|
||||
c.fillImportMap(f.file, pkg.Pkg)
|
||||
ast.Inspect(f.decl, c.scope.analyzeVoidCalls)
|
||||
ast.Walk(c, f.decl.Body)
|
||||
if c.scope.voidCalls[n] {
|
||||
for i := 0; i < f.decl.Type.Results.NumFields(); i++ {
|
||||
emit.Opcodes(c.prog.BinWriter, opcode.DROP)
|
||||
}
|
||||
}
|
||||
c.importMap = oldMap
|
||||
c.pkgInfoInline = c.pkgInfoInline[:len(c.pkgInfoInline)-1]
|
||||
}
|
125
pkg/compiler/inline_test.go
Normal file
125
pkg/compiler/inline_test.go
Normal file
|
@ -0,0 +1,125 @@
|
|||
package compiler_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/big"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/nspcc-dev/neo-go/pkg/compiler"
|
||||
"github.com/nspcc-dev/neo-go/pkg/vm/opcode"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func checkCallCount(t *testing.T, src string, expectedCall, expectedInitSlot int) {
|
||||
v := vmAndCompile(t, src)
|
||||
ctx := v.Context()
|
||||
actualCall := 0
|
||||
actualInitSlot := 0
|
||||
|
||||
for op, _, err := ctx.Next(); ; op, _, err = ctx.Next() {
|
||||
require.NoError(t, err)
|
||||
switch op {
|
||||
case opcode.CALL, opcode.CALLL:
|
||||
actualCall++
|
||||
case opcode.INITSLOT:
|
||||
actualInitSlot++
|
||||
}
|
||||
if ctx.IP() == ctx.LenInstr() {
|
||||
break
|
||||
}
|
||||
}
|
||||
require.Equal(t, expectedCall, actualCall)
|
||||
require.Equal(t, expectedInitSlot, actualInitSlot)
|
||||
}
|
||||
|
||||
func TestInline(t *testing.T) {
|
||||
srcTmpl := `package foo
|
||||
import "github.com/nspcc-dev/neo-go/pkg/compiler/testdata/inline"
|
||||
// local alias
|
||||
func sum(a, b int) int {
|
||||
return 42
|
||||
}
|
||||
func Main() int {
|
||||
%s
|
||||
}`
|
||||
t.Run("no return", func(t *testing.T) {
|
||||
src := fmt.Sprintf(srcTmpl, `inline.NoArgsNoReturn()
|
||||
return 1`)
|
||||
checkCallCount(t, src, 0, 1)
|
||||
eval(t, src, big.NewInt(1))
|
||||
})
|
||||
t.Run("has return, dropped", func(t *testing.T) {
|
||||
src := fmt.Sprintf(srcTmpl, `inline.NoArgsReturn1()
|
||||
return 2`)
|
||||
checkCallCount(t, src, 0, 1)
|
||||
eval(t, src, big.NewInt(2))
|
||||
})
|
||||
t.Run("drop twice", func(t *testing.T) {
|
||||
src := fmt.Sprintf(srcTmpl, `inline.DropInsideInline()
|
||||
return 42`)
|
||||
checkCallCount(t, src, 0, 1)
|
||||
eval(t, src, big.NewInt(42))
|
||||
})
|
||||
t.Run("no args return 1", func(t *testing.T) {
|
||||
src := fmt.Sprintf(srcTmpl, `return inline.NoArgsReturn1()`)
|
||||
checkCallCount(t, src, 0, 1)
|
||||
eval(t, src, big.NewInt(1))
|
||||
})
|
||||
t.Run("sum", func(t *testing.T) {
|
||||
src := fmt.Sprintf(srcTmpl, `return inline.Sum(1, 2)`)
|
||||
checkCallCount(t, src, 0, 1)
|
||||
eval(t, src, big.NewInt(3))
|
||||
})
|
||||
t.Run("sum squared (nested inline)", func(t *testing.T) {
|
||||
src := fmt.Sprintf(srcTmpl, `return inline.SumSquared(1, 2)`)
|
||||
checkCallCount(t, src, 0, 1)
|
||||
eval(t, src, big.NewInt(9))
|
||||
})
|
||||
t.Run("inline function in inline function parameter", func(t *testing.T) {
|
||||
src := fmt.Sprintf(srcTmpl, `return inline.Sum(inline.SumSquared(1, 2), inline.Sum(3, 4))`)
|
||||
checkCallCount(t, src, 0, 1)
|
||||
eval(t, src, big.NewInt(9+3+4))
|
||||
})
|
||||
t.Run("global name clash", func(t *testing.T) {
|
||||
src := fmt.Sprintf(srcTmpl, `return inline.GetSumSameName()`)
|
||||
checkCallCount(t, src, 0, 1)
|
||||
eval(t, src, big.NewInt(42))
|
||||
})
|
||||
t.Run("local name clash", func(t *testing.T) {
|
||||
src := fmt.Sprintf(srcTmpl, `return inline.Sum(inline.SumSquared(1, 2), sum(3, 4))`)
|
||||
checkCallCount(t, src, 1, 2)
|
||||
eval(t, src, big.NewInt(51))
|
||||
})
|
||||
}
|
||||
|
||||
func TestInlineConversion(t *testing.T) {
|
||||
src1 := `package foo
|
||||
import "github.com/nspcc-dev/neo-go/pkg/compiler/testdata/inline"
|
||||
var _ = inline.A
|
||||
func Main() int {
|
||||
a := 2
|
||||
return inline.SumSquared(1, a)
|
||||
}`
|
||||
b1, err := compiler.Compile("foo.go", strings.NewReader(src1))
|
||||
require.NoError(t, err)
|
||||
|
||||
src2 := `package foo
|
||||
import "github.com/nspcc-dev/neo-go/pkg/compiler/testdata/inline"
|
||||
var _ = inline.A
|
||||
func Main() int {
|
||||
a := 2
|
||||
{
|
||||
b := 1
|
||||
c := a
|
||||
{
|
||||
bb := b
|
||||
cc := c
|
||||
return (bb + cc) * (b + c)
|
||||
}
|
||||
}
|
||||
}`
|
||||
b2, err := compiler.Compile("foo.go", strings.NewReader(src2))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, b2, b1)
|
||||
}
|
7
pkg/compiler/testdata/inline/a/a.go
vendored
Normal file
7
pkg/compiler/testdata/inline/a/a.go
vendored
Normal file
|
@ -0,0 +1,7 @@
|
|||
package a
|
||||
|
||||
var A = 29
|
||||
|
||||
func GetA() int {
|
||||
return A
|
||||
}
|
7
pkg/compiler/testdata/inline/b/b.go
vendored
Normal file
7
pkg/compiler/testdata/inline/b/b.go
vendored
Normal file
|
@ -0,0 +1,7 @@
|
|||
package b
|
||||
|
||||
var A = 12
|
||||
|
||||
func GetA() int {
|
||||
return A
|
||||
}
|
32
pkg/compiler/testdata/inline/inline.go
vendored
Normal file
32
pkg/compiler/testdata/inline/inline.go
vendored
Normal file
|
@ -0,0 +1,32 @@
|
|||
package inline
|
||||
|
||||
import (
|
||||
"github.com/nspcc-dev/neo-go/pkg/compiler/testdata/inline/a"
|
||||
"github.com/nspcc-dev/neo-go/pkg/compiler/testdata/inline/b"
|
||||
)
|
||||
|
||||
func NoArgsNoReturn() {}
|
||||
func NoArgsReturn1() int {
|
||||
return 1
|
||||
}
|
||||
func Sum(a, b int) int {
|
||||
return a + b
|
||||
}
|
||||
func sum(x, y int) int {
|
||||
return x + y
|
||||
}
|
||||
func SumSquared(a, b int) int {
|
||||
return sum(a, b) * (a + b)
|
||||
}
|
||||
|
||||
var A = 1
|
||||
|
||||
func GetSumSameName() int {
|
||||
return a.GetA() + b.GetA() + A
|
||||
}
|
||||
|
||||
func DropInsideInline() int {
|
||||
sum(1, 2)
|
||||
sum(3, 4)
|
||||
return 7
|
||||
}
|
|
@ -8,6 +8,11 @@ import (
|
|||
)
|
||||
|
||||
func (c *codegen) typeAndValueOf(e ast.Expr) types.TypeAndValue {
|
||||
for i := len(c.pkgInfoInline) - 1; i >= 0; i-- {
|
||||
if tv, ok := c.pkgInfoInline[i].Types[e]; ok {
|
||||
return tv
|
||||
}
|
||||
}
|
||||
return c.typeInfo.Types[e]
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in a new issue