Merge pull request #1774 from nspcc-dev/fix/compiler

Allow to use inlined functions during global var init
This commit is contained in:
Roman Khimov 2021-02-25 16:16:21 +03:00 committed by GitHub
commit 549596bc1f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 112 additions and 36 deletions

View file

@ -46,13 +46,24 @@ func (c *codegen) traverseGlobals() (int, int, int) {
var n, nConst int var n, nConst int
initLocals := -1 initLocals := -1
deployLocals := -1 deployLocals := -1
c.ForEachFile(func(f *ast.File, _ *types.Package) { c.ForEachFile(func(f *ast.File, pkg *types.Package) {
nv, nc := countGlobals(f) nv, nc := countGlobals(f)
n += nv n += nv
nConst += nc nConst += nc
if initLocals == -1 || deployLocals == -1 || !hasDefer { if initLocals == -1 || deployLocals == -1 || !hasDefer {
ast.Inspect(f, func(node ast.Node) bool { ast.Inspect(f, func(node ast.Node) bool {
switch n := node.(type) { switch n := node.(type) {
case *ast.GenDecl:
if n.Tok == token.VAR {
for i := range n.Specs {
for _, v := range n.Specs[i].(*ast.ValueSpec).Values {
num := c.countLocalsCall(v, pkg)
if num > initLocals {
initLocals = num
}
}
}
}
case *ast.FuncDecl: case *ast.FuncDecl:
if isInitFunc(n) { if isInitFunc(n) {
num, _ := c.countLocals(n) num, _ := c.countLocals(n)

View file

@ -102,6 +102,50 @@ func (c *funcScope) analyzeVoidCalls(node ast.Node) bool {
return true return true
} }
func (c *codegen) countLocalsCall(n ast.Expr, pkg *types.Package) int {
ce, ok := n.(*ast.CallExpr)
if !ok {
return -1
}
var size int
var name string
switch fun := ce.Fun.(type) {
case *ast.Ident:
var pkgName string
if pkg != nil {
pkgName = pkg.Path()
}
name = c.getIdentName(pkgName, fun.Name)
case *ast.SelectorExpr:
name, _ = c.getFuncNameFromSelector(fun)
default:
return 0
}
if inner, ok := c.funcs[name]; ok && canInline(name) {
sig, ok := c.typeOf(ce.Fun).(*types.Signature)
if !ok {
info := c.buildInfo.program.Package(pkg.Path())
sig = info.Types[ce.Fun].Type.(*types.Signature)
}
for i := range ce.Args {
switch ce.Args[i].(type) {
case *ast.Ident:
case *ast.BasicLit:
default:
size++
}
}
// Variadic with direct var args.
if sig.Variadic() && !ce.Ellipsis.IsValid() {
size++
}
innerSz, _ := c.countLocalsInline(inner.decl, inner.pkg, inner)
size += innerSz
}
return size
}
func (c *codegen) countLocals(decl *ast.FuncDecl) (int, bool) { func (c *codegen) countLocals(decl *ast.FuncDecl) (int, bool) {
return c.countLocalsInline(decl, nil, nil) return c.countLocalsInline(decl, nil, nil)
} }
@ -117,40 +161,7 @@ func (c *codegen) countLocalsInline(decl *ast.FuncDecl, pkg *types.Package, f *f
ast.Inspect(decl, func(n ast.Node) bool { ast.Inspect(decl, func(n ast.Node) bool {
switch n := n.(type) { switch n := n.(type) {
case *ast.CallExpr: case *ast.CallExpr:
var name string size += c.countLocalsCall(n, pkg)
switch fun := n.Fun.(type) {
case *ast.Ident:
var pkgName string
if pkg != nil {
pkgName = pkg.Path()
}
name = c.getIdentName(pkgName, fun.Name)
case *ast.SelectorExpr:
name, _ = c.getFuncNameFromSelector(fun)
default:
return false
}
if inner, ok := c.funcs[name]; ok && canInline(name) {
sig, ok := c.typeOf(n.Fun).(*types.Signature)
if !ok {
info := c.buildInfo.program.Package(pkg.Path())
sig = info.Types[n.Fun].Type.(*types.Signature)
}
for i := range n.Args {
switch n.Args[i].(type) {
case *ast.Ident:
case *ast.BasicLit:
default:
size++
}
}
// Variadic with direct var args.
if sig.Variadic() && !n.Ellipsis.IsValid() {
size++
}
innerSz, _ := c.countLocalsInline(inner.decl, inner.pkg, inner)
size += innerSz
}
return false return false
case *ast.FuncType: case *ast.FuncType:
num := n.Results.NumFields() num := n.Results.NumFields()

View file

@ -19,6 +19,14 @@ func (c *codegen) inlineCall(f *funcScope, n *ast.CallExpr) {
pkg := c.buildInfo.program.Package(f.pkg.Path()) pkg := c.buildInfo.program.Package(f.pkg.Path())
sig := c.typeOf(n.Fun).(*types.Signature) sig := c.typeOf(n.Fun).(*types.Signature)
// When inlined call is used during global initialization
// there is no func scope, thus this if.
if c.scope == nil {
c.scope = &funcScope{}
c.scope.vars.newScope()
defer func() { c.scope = nil }()
}
// Arguments need to be walked with the current scope, // Arguments need to be walked with the current scope,
// while stored in the new. // while stored in the new.
oldScope := c.scope.vars.locals oldScope := c.scope.vars.locals

View file

@ -115,6 +115,32 @@ func TestInline(t *testing.T) {
}) })
} }
func TestInlineGlobalVariable(t *testing.T) {
t.Run("simple", func(t *testing.T) {
src := `package foo
import "github.com/nspcc-dev/neo-go/pkg/compiler/testdata/inline"
var a = inline.Sum(1, 2)
func Main() int {
return a
}`
eval(t, src, big.NewInt(3))
})
t.Run("complex", func(t *testing.T) {
src := `package foo
import "github.com/nspcc-dev/neo-go/pkg/compiler/testdata/inline"
var a = inline.Sum(3, 4)
var b = inline.SumSquared(1, 2)
var c = a + b
func init() {
c--
}
func Main() int {
return c
}`
eval(t, src, big.NewInt(15))
})
}
func TestInlineConversion(t *testing.T) { func TestInlineConversion(t *testing.T) {
src1 := `package foo src1 := `package foo
import "github.com/nspcc-dev/neo-go/pkg/compiler/testdata/inline" import "github.com/nspcc-dev/neo-go/pkg/compiler/testdata/inline"

View file

@ -4,10 +4,12 @@ import (
"math/big" "math/big"
"testing" "testing"
"github.com/nspcc-dev/neo-go/pkg/core/interop/interopnames"
istorage "github.com/nspcc-dev/neo-go/pkg/core/interop/storage" istorage "github.com/nspcc-dev/neo-go/pkg/core/interop/storage"
"github.com/nspcc-dev/neo-go/pkg/interop/contract" "github.com/nspcc-dev/neo-go/pkg/interop/contract"
"github.com/nspcc-dev/neo-go/pkg/interop/storage" "github.com/nspcc-dev/neo-go/pkg/interop/storage"
"github.com/nspcc-dev/neo-go/pkg/smartcontract/callflag" "github.com/nspcc-dev/neo-go/pkg/smartcontract/callflag"
"github.com/nspcc-dev/neo-go/pkg/vm"
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -72,3 +74,21 @@ func TestNotify(t *testing.T) {
assert.Equal(t, "single", s.events[1].Name) assert.Equal(t, "single", s.events[1].Name)
assert.Equal(t, []stackitem.Item{}, s.events[1].Item.Value()) assert.Equal(t, []stackitem.Item{}, s.events[1].Item.Value())
} }
func TestSyscallInGlobalInit(t *testing.T) {
src := `package foo
import "github.com/nspcc-dev/neo-go/pkg/interop/binary"
var a = binary.Base58Decode([]byte("5T"))
func Main() []byte {
return a
}`
v, s := vmAndCompileInterop(t, src)
s.interops[interopnames.ToID([]byte(interopnames.SystemBinaryBase58Decode))] = func(v *vm.VM) error {
s := v.Estack().Pop().Value().([]byte)
require.Equal(t, "5T", string(s))
v.Estack().PushVal([]byte{1, 2})
return nil
}
require.NoError(t, v.Run())
require.Equal(t, []byte{1, 2}, v.Estack().Pop().Value())
}

View file

@ -42,7 +42,7 @@ func Base58Encode(b []byte) string {
// Base58Decode decodes given base58 string represented as a byte slice into // Base58Decode decodes given base58 string represented as a byte slice into
// a new byte slice. It uses `System.Binary.Base58Decode` syscall. // a new byte slice. It uses `System.Binary.Base58Decode` syscall.
func Base58Decode(b []byte) []byte { func Base58Decode(b []byte) []byte {
return neogointernal.Syscall1("System.Binary.Base64Decode", b).([]byte) return neogointernal.Syscall1("System.Binary.Base58Decode", b).([]byte)
} }
// Itoa converts num in a given base to string. Base should be either 10 or 16. // Itoa converts num in a given base to string. Base should be either 10 or 16.