forked from TrueCloudLab/neoneo-go
compiler: support basic inlining
This commit is contained in:
parent
1f238ce6fd
commit
1ae0d022dd
9 changed files with 258 additions and 4 deletions
|
@ -304,3 +304,11 @@ func canConvert(s string) bool {
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// canInline returns true if function is to be inlined.
|
||||||
|
// Currently there is a static list of function which are inlined,
|
||||||
|
// this may change in future.
|
||||||
|
func canInline(s string) bool {
|
||||||
|
return isNativeHelpersPath(s) ||
|
||||||
|
strings.HasPrefix(s, "github.com/nspcc-dev/neo-go/pkg/compiler/testdata/inline")
|
||||||
|
}
|
||||||
|
|
|
@ -31,6 +31,8 @@ type codegen struct {
|
||||||
|
|
||||||
// Type information.
|
// Type information.
|
||||||
typeInfo *types.Info
|
typeInfo *types.Info
|
||||||
|
// pkgInfoInline is stack of type information for packages containing inline functions.
|
||||||
|
pkgInfoInline []*loader.PackageInfo
|
||||||
|
|
||||||
// A mapping of func identifiers with their scope.
|
// A mapping of func identifiers with their scope.
|
||||||
funcs map[string]*funcScope
|
funcs map[string]*funcScope
|
||||||
|
@ -406,6 +408,7 @@ func (c *codegen) convertFuncDecl(file ast.Node, decl *ast.FuncDecl, pkg *types.
|
||||||
if sizeArg > 255 {
|
if sizeArg > 255 {
|
||||||
c.prog.Err = errors.New("maximum of 255 local variables is allowed")
|
c.prog.Err = errors.New("maximum of 255 local variables is allowed")
|
||||||
}
|
}
|
||||||
|
sizeLoc = 255 // FIXME count locals including inline variables
|
||||||
if sizeLoc != 0 || sizeArg != 0 {
|
if sizeLoc != 0 || sizeArg != 0 {
|
||||||
emit.Instruction(c.prog.BinWriter, opcode.INITSLOT, []byte{byte(sizeLoc), byte(sizeArg)})
|
emit.Instruction(c.prog.BinWriter, opcode.INITSLOT, []byte{byte(sizeLoc), byte(sizeArg)})
|
||||||
}
|
}
|
||||||
|
@ -623,7 +626,9 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
|
||||||
c.processDefers()
|
c.processDefers()
|
||||||
|
|
||||||
c.saveSequencePoint(n)
|
c.saveSequencePoint(n)
|
||||||
|
if len(c.pkgInfoInline) == 0 {
|
||||||
emit.Opcodes(c.prog.BinWriter, opcode.RET)
|
emit.Opcodes(c.prog.BinWriter, opcode.RET)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|
||||||
case *ast.IfStmt:
|
case *ast.IfStmt:
|
||||||
|
@ -800,7 +805,12 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
|
||||||
|
|
||||||
switch fun := n.Fun.(type) {
|
switch fun := n.Fun.(type) {
|
||||||
case *ast.Ident:
|
case *ast.Ident:
|
||||||
f, ok = c.funcs[c.getIdentName("", fun.Name)]
|
var pkgName string
|
||||||
|
if len(c.pkgInfoInline) != 0 {
|
||||||
|
pkgName = c.pkgInfoInline[len(c.pkgInfoInline)-1].Pkg.Path()
|
||||||
|
}
|
||||||
|
f, ok = c.funcs[c.getIdentName(pkgName, fun.Name)]
|
||||||
|
|
||||||
isBuiltin = isGoBuiltin(fun.Name)
|
isBuiltin = isGoBuiltin(fun.Name)
|
||||||
if !ok && !isBuiltin {
|
if !ok && !isBuiltin {
|
||||||
name = fun.Name
|
name = fun.Name
|
||||||
|
@ -809,6 +819,10 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
|
||||||
if fun.Obj != nil && fun.Obj.Kind == ast.Var {
|
if fun.Obj != nil && fun.Obj.Kind == ast.Var {
|
||||||
isFunc = true
|
isFunc = true
|
||||||
}
|
}
|
||||||
|
if ok && canInline(f.pkg.Path()) {
|
||||||
|
c.inlineCall(f, n)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
case *ast.SelectorExpr:
|
case *ast.SelectorExpr:
|
||||||
// If this is a method call we need to walk the AST to load the struct locally.
|
// If this is a method call we need to walk the AST to load the struct locally.
|
||||||
// Otherwise this is a function call from a imported package and we can call it
|
// Otherwise this is a function call from a imported package and we can call it
|
||||||
|
@ -824,6 +838,10 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
|
||||||
if ok {
|
if ok {
|
||||||
f.selector = fun.X.(*ast.Ident)
|
f.selector = fun.X.(*ast.Ident)
|
||||||
isBuiltin = isCustomBuiltin(f)
|
isBuiltin = isCustomBuiltin(f)
|
||||||
|
if canInline(f.pkg.Path()) {
|
||||||
|
c.inlineCall(f, n)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
typ := c.typeOf(fun)
|
typ := c.typeOf(fun)
|
||||||
if _, ok := typ.(*types.Signature); ok {
|
if _, ok := typ.(*types.Signature); ok {
|
||||||
|
@ -1919,7 +1937,7 @@ func (c *codegen) compile(info *buildInfo, pkg *loader.PackageInfo) error {
|
||||||
// of bytecode space.
|
// of bytecode space.
|
||||||
name := c.getFuncNameFromDecl(pkg.Path(), n)
|
name := c.getFuncNameFromDecl(pkg.Path(), n)
|
||||||
if !isInitFunc(n) && !isDeployFunc(n) && funUsage.funcUsed(name) &&
|
if !isInitFunc(n) && !isDeployFunc(n) && funUsage.funcUsed(name) &&
|
||||||
(!isInteropPath(pkg.Path()) || isNativeHelpersPath(pkg.Path())) {
|
(!isInteropPath(pkg.Path()) && !canInline(pkg.Path())) {
|
||||||
c.convertFuncDecl(f, n, pkg)
|
c.convertFuncDecl(f, n, pkg)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1970,7 +1988,8 @@ func (c *codegen) resolveFuncDecls(f *ast.File, pkg *types.Package) {
|
||||||
for _, decl := range f.Decls {
|
for _, decl := range f.Decls {
|
||||||
switch n := decl.(type) {
|
switch n := decl.(type) {
|
||||||
case *ast.FuncDecl:
|
case *ast.FuncDecl:
|
||||||
c.newFunc(n)
|
fs := c.newFunc(n)
|
||||||
|
fs.file = f
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -22,6 +22,8 @@ type funcScope struct {
|
||||||
// Package where the function is defined.
|
// Package where the function is defined.
|
||||||
pkg *types.Package
|
pkg *types.Package
|
||||||
|
|
||||||
|
file *ast.File
|
||||||
|
|
||||||
// Program label of the scope
|
// Program label of the scope
|
||||||
label uint16
|
label uint16
|
||||||
|
|
||||||
|
|
49
pkg/compiler/inline.go
Normal file
49
pkg/compiler/inline.go
Normal file
|
@ -0,0 +1,49 @@
|
||||||
|
package compiler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"go/ast"
|
||||||
|
"go/types"
|
||||||
|
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/vm/emit"
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/vm/opcode"
|
||||||
|
)
|
||||||
|
|
||||||
|
// inlineCall inlines call of n for function represented by f.
|
||||||
|
// Call `f(a,b)` for definition `func f(x,y int)` is translated to block:
|
||||||
|
// {
|
||||||
|
// x := a
|
||||||
|
// y := b
|
||||||
|
// <inline body of f directly>
|
||||||
|
// }
|
||||||
|
func (c *codegen) inlineCall(f *funcScope, n *ast.CallExpr) {
|
||||||
|
pkg := c.buildInfo.program.Package(f.pkg.Path())
|
||||||
|
sig := c.typeOf(n.Fun).(*types.Signature)
|
||||||
|
|
||||||
|
// Arguments need to be walked with the current scope,
|
||||||
|
// while stored in the new.
|
||||||
|
oldScope := c.scope.vars.locals
|
||||||
|
c.scope.vars.newScope()
|
||||||
|
newScope := c.scope.vars.locals
|
||||||
|
defer c.scope.vars.dropScope()
|
||||||
|
for i := range n.Args {
|
||||||
|
c.scope.vars.locals = oldScope
|
||||||
|
ast.Walk(c, n.Args[i])
|
||||||
|
c.scope.vars.locals = newScope
|
||||||
|
name := sig.Params().At(i).Name()
|
||||||
|
c.scope.newLocal(name)
|
||||||
|
c.emitStoreVar("", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.pkgInfoInline = append(c.pkgInfoInline, pkg)
|
||||||
|
oldMap := c.importMap
|
||||||
|
c.fillImportMap(f.file, pkg.Pkg)
|
||||||
|
ast.Inspect(f.decl, c.scope.analyzeVoidCalls)
|
||||||
|
ast.Walk(c, f.decl.Body)
|
||||||
|
if c.scope.voidCalls[n] {
|
||||||
|
for i := 0; i < f.decl.Type.Results.NumFields(); i++ {
|
||||||
|
emit.Opcodes(c.prog.BinWriter, opcode.DROP)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
c.importMap = oldMap
|
||||||
|
c.pkgInfoInline = c.pkgInfoInline[:len(c.pkgInfoInline)-1]
|
||||||
|
}
|
125
pkg/compiler/inline_test.go
Normal file
125
pkg/compiler/inline_test.go
Normal file
|
@ -0,0 +1,125 @@
|
||||||
|
package compiler_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"math/big"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/compiler"
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/vm/opcode"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func checkCallCount(t *testing.T, src string, expectedCall, expectedInitSlot int) {
|
||||||
|
v := vmAndCompile(t, src)
|
||||||
|
ctx := v.Context()
|
||||||
|
actualCall := 0
|
||||||
|
actualInitSlot := 0
|
||||||
|
|
||||||
|
for op, _, err := ctx.Next(); ; op, _, err = ctx.Next() {
|
||||||
|
require.NoError(t, err)
|
||||||
|
switch op {
|
||||||
|
case opcode.CALL, opcode.CALLL:
|
||||||
|
actualCall++
|
||||||
|
case opcode.INITSLOT:
|
||||||
|
actualInitSlot++
|
||||||
|
}
|
||||||
|
if ctx.IP() == ctx.LenInstr() {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
require.Equal(t, expectedCall, actualCall)
|
||||||
|
require.Equal(t, expectedInitSlot, actualInitSlot)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInline(t *testing.T) {
|
||||||
|
srcTmpl := `package foo
|
||||||
|
import "github.com/nspcc-dev/neo-go/pkg/compiler/testdata/inline"
|
||||||
|
// local alias
|
||||||
|
func sum(a, b int) int {
|
||||||
|
return 42
|
||||||
|
}
|
||||||
|
func Main() int {
|
||||||
|
%s
|
||||||
|
}`
|
||||||
|
t.Run("no return", func(t *testing.T) {
|
||||||
|
src := fmt.Sprintf(srcTmpl, `inline.NoArgsNoReturn()
|
||||||
|
return 1`)
|
||||||
|
checkCallCount(t, src, 0, 1)
|
||||||
|
eval(t, src, big.NewInt(1))
|
||||||
|
})
|
||||||
|
t.Run("has return, dropped", func(t *testing.T) {
|
||||||
|
src := fmt.Sprintf(srcTmpl, `inline.NoArgsReturn1()
|
||||||
|
return 2`)
|
||||||
|
checkCallCount(t, src, 0, 1)
|
||||||
|
eval(t, src, big.NewInt(2))
|
||||||
|
})
|
||||||
|
t.Run("drop twice", func(t *testing.T) {
|
||||||
|
src := fmt.Sprintf(srcTmpl, `inline.DropInsideInline()
|
||||||
|
return 42`)
|
||||||
|
checkCallCount(t, src, 0, 1)
|
||||||
|
eval(t, src, big.NewInt(42))
|
||||||
|
})
|
||||||
|
t.Run("no args return 1", func(t *testing.T) {
|
||||||
|
src := fmt.Sprintf(srcTmpl, `return inline.NoArgsReturn1()`)
|
||||||
|
checkCallCount(t, src, 0, 1)
|
||||||
|
eval(t, src, big.NewInt(1))
|
||||||
|
})
|
||||||
|
t.Run("sum", func(t *testing.T) {
|
||||||
|
src := fmt.Sprintf(srcTmpl, `return inline.Sum(1, 2)`)
|
||||||
|
checkCallCount(t, src, 0, 1)
|
||||||
|
eval(t, src, big.NewInt(3))
|
||||||
|
})
|
||||||
|
t.Run("sum squared (nested inline)", func(t *testing.T) {
|
||||||
|
src := fmt.Sprintf(srcTmpl, `return inline.SumSquared(1, 2)`)
|
||||||
|
checkCallCount(t, src, 0, 1)
|
||||||
|
eval(t, src, big.NewInt(9))
|
||||||
|
})
|
||||||
|
t.Run("inline function in inline function parameter", func(t *testing.T) {
|
||||||
|
src := fmt.Sprintf(srcTmpl, `return inline.Sum(inline.SumSquared(1, 2), inline.Sum(3, 4))`)
|
||||||
|
checkCallCount(t, src, 0, 1)
|
||||||
|
eval(t, src, big.NewInt(9+3+4))
|
||||||
|
})
|
||||||
|
t.Run("global name clash", func(t *testing.T) {
|
||||||
|
src := fmt.Sprintf(srcTmpl, `return inline.GetSumSameName()`)
|
||||||
|
checkCallCount(t, src, 0, 1)
|
||||||
|
eval(t, src, big.NewInt(42))
|
||||||
|
})
|
||||||
|
t.Run("local name clash", func(t *testing.T) {
|
||||||
|
src := fmt.Sprintf(srcTmpl, `return inline.Sum(inline.SumSquared(1, 2), sum(3, 4))`)
|
||||||
|
checkCallCount(t, src, 1, 2)
|
||||||
|
eval(t, src, big.NewInt(51))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInlineConversion(t *testing.T) {
|
||||||
|
src1 := `package foo
|
||||||
|
import "github.com/nspcc-dev/neo-go/pkg/compiler/testdata/inline"
|
||||||
|
var _ = inline.A
|
||||||
|
func Main() int {
|
||||||
|
a := 2
|
||||||
|
return inline.SumSquared(1, a)
|
||||||
|
}`
|
||||||
|
b1, err := compiler.Compile("foo.go", strings.NewReader(src1))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
src2 := `package foo
|
||||||
|
import "github.com/nspcc-dev/neo-go/pkg/compiler/testdata/inline"
|
||||||
|
var _ = inline.A
|
||||||
|
func Main() int {
|
||||||
|
a := 2
|
||||||
|
{
|
||||||
|
b := 1
|
||||||
|
c := a
|
||||||
|
{
|
||||||
|
bb := b
|
||||||
|
cc := c
|
||||||
|
return (bb + cc) * (b + c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}`
|
||||||
|
b2, err := compiler.Compile("foo.go", strings.NewReader(src2))
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, b2, b1)
|
||||||
|
}
|
7
pkg/compiler/testdata/inline/a/a.go
vendored
Normal file
7
pkg/compiler/testdata/inline/a/a.go
vendored
Normal file
|
@ -0,0 +1,7 @@
|
||||||
|
package a
|
||||||
|
|
||||||
|
var A = 29
|
||||||
|
|
||||||
|
func GetA() int {
|
||||||
|
return A
|
||||||
|
}
|
7
pkg/compiler/testdata/inline/b/b.go
vendored
Normal file
7
pkg/compiler/testdata/inline/b/b.go
vendored
Normal file
|
@ -0,0 +1,7 @@
|
||||||
|
package b
|
||||||
|
|
||||||
|
var A = 12
|
||||||
|
|
||||||
|
func GetA() int {
|
||||||
|
return A
|
||||||
|
}
|
32
pkg/compiler/testdata/inline/inline.go
vendored
Normal file
32
pkg/compiler/testdata/inline/inline.go
vendored
Normal file
|
@ -0,0 +1,32 @@
|
||||||
|
package inline
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/compiler/testdata/inline/a"
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/compiler/testdata/inline/b"
|
||||||
|
)
|
||||||
|
|
||||||
|
func NoArgsNoReturn() {}
|
||||||
|
func NoArgsReturn1() int {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
func Sum(a, b int) int {
|
||||||
|
return a + b
|
||||||
|
}
|
||||||
|
func sum(x, y int) int {
|
||||||
|
return x + y
|
||||||
|
}
|
||||||
|
func SumSquared(a, b int) int {
|
||||||
|
return sum(a, b) * (a + b)
|
||||||
|
}
|
||||||
|
|
||||||
|
var A = 1
|
||||||
|
|
||||||
|
func GetSumSameName() int {
|
||||||
|
return a.GetA() + b.GetA() + A
|
||||||
|
}
|
||||||
|
|
||||||
|
func DropInsideInline() int {
|
||||||
|
sum(1, 2)
|
||||||
|
sum(3, 4)
|
||||||
|
return 7
|
||||||
|
}
|
|
@ -8,6 +8,11 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func (c *codegen) typeAndValueOf(e ast.Expr) types.TypeAndValue {
|
func (c *codegen) typeAndValueOf(e ast.Expr) types.TypeAndValue {
|
||||||
|
for i := len(c.pkgInfoInline) - 1; i >= 0; i-- {
|
||||||
|
if tv, ok := c.pkgInfoInline[i].Types[e]; ok {
|
||||||
|
return tv
|
||||||
|
}
|
||||||
|
}
|
||||||
return c.typeInfo.Types[e]
|
return c.typeInfo.Types[e]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue