From 0a4ff9d3e4a9ab432fd5812eb18c98e03b5a7432 Mon Sep 17 00:00:00 2001 From: Evgeniy Stratonikov Date: Thu, 25 Feb 2021 15:12:16 +0300 Subject: [PATCH] compiler: allow to use inlined functions to init globals --- pkg/compiler/analysis.go | 13 +++++- pkg/compiler/func_scope.go | 79 ++++++++++++++++++++---------------- pkg/compiler/inline.go | 8 ++++ pkg/compiler/inline_test.go | 26 ++++++++++++ pkg/compiler/syscall_test.go | 20 +++++++++ 5 files changed, 111 insertions(+), 35 deletions(-) diff --git a/pkg/compiler/analysis.go b/pkg/compiler/analysis.go index 1ffd979c7..4efbcc44c 100644 --- a/pkg/compiler/analysis.go +++ b/pkg/compiler/analysis.go @@ -46,13 +46,24 @@ func (c *codegen) traverseGlobals() (int, int, int) { var n, nConst int initLocals := -1 deployLocals := -1 - c.ForEachFile(func(f *ast.File, _ *types.Package) { + c.ForEachFile(func(f *ast.File, pkg *types.Package) { nv, nc := countGlobals(f) n += nv nConst += nc if initLocals == -1 || deployLocals == -1 || !hasDefer { ast.Inspect(f, func(node ast.Node) bool { switch n := node.(type) { + case *ast.GenDecl: + if n.Tok == token.VAR { + for i := range n.Specs { + for _, v := range n.Specs[i].(*ast.ValueSpec).Values { + num := c.countLocalsCall(v, pkg) + if num > initLocals { + initLocals = num + } + } + } + } case *ast.FuncDecl: if isInitFunc(n) { num, _ := c.countLocals(n) diff --git a/pkg/compiler/func_scope.go b/pkg/compiler/func_scope.go index eabc09c79..5e2534287 100644 --- a/pkg/compiler/func_scope.go +++ b/pkg/compiler/func_scope.go @@ -102,6 +102,50 @@ func (c *funcScope) analyzeVoidCalls(node ast.Node) bool { return true } +func (c *codegen) countLocalsCall(n ast.Expr, pkg *types.Package) int { + ce, ok := n.(*ast.CallExpr) + if !ok { + return -1 + } + + var size int + var name string + switch fun := ce.Fun.(type) { + case *ast.Ident: + var pkgName string + if pkg != nil { + pkgName = pkg.Path() + } + name = c.getIdentName(pkgName, fun.Name) + case *ast.SelectorExpr: + name, _ = c.getFuncNameFromSelector(fun) + default: + return 0 + } + if inner, ok := c.funcs[name]; ok && canInline(name) { + sig, ok := c.typeOf(ce.Fun).(*types.Signature) + if !ok { + info := c.buildInfo.program.Package(pkg.Path()) + sig = info.Types[ce.Fun].Type.(*types.Signature) + } + for i := range ce.Args { + switch ce.Args[i].(type) { + case *ast.Ident: + case *ast.BasicLit: + default: + size++ + } + } + // Variadic with direct var args. + if sig.Variadic() && !ce.Ellipsis.IsValid() { + size++ + } + innerSz, _ := c.countLocalsInline(inner.decl, inner.pkg, inner) + size += innerSz + } + return size +} + func (c *codegen) countLocals(decl *ast.FuncDecl) (int, bool) { return c.countLocalsInline(decl, nil, nil) } @@ -117,40 +161,7 @@ func (c *codegen) countLocalsInline(decl *ast.FuncDecl, pkg *types.Package, f *f ast.Inspect(decl, func(n ast.Node) bool { switch n := n.(type) { case *ast.CallExpr: - var name string - switch fun := n.Fun.(type) { - case *ast.Ident: - var pkgName string - if pkg != nil { - pkgName = pkg.Path() - } - name = c.getIdentName(pkgName, fun.Name) - case *ast.SelectorExpr: - name, _ = c.getFuncNameFromSelector(fun) - default: - return false - } - if inner, ok := c.funcs[name]; ok && canInline(name) { - sig, ok := c.typeOf(n.Fun).(*types.Signature) - if !ok { - info := c.buildInfo.program.Package(pkg.Path()) - sig = info.Types[n.Fun].Type.(*types.Signature) - } - for i := range n.Args { - switch n.Args[i].(type) { - case *ast.Ident: - case *ast.BasicLit: - default: - size++ - } - } - // Variadic with direct var args. - if sig.Variadic() && !n.Ellipsis.IsValid() { - size++ - } - innerSz, _ := c.countLocalsInline(inner.decl, inner.pkg, inner) - size += innerSz - } + size += c.countLocalsCall(n, pkg) return false case *ast.FuncType: num := n.Results.NumFields() diff --git a/pkg/compiler/inline.go b/pkg/compiler/inline.go index 92d89652b..4998d250e 100644 --- a/pkg/compiler/inline.go +++ b/pkg/compiler/inline.go @@ -19,6 +19,14 @@ func (c *codegen) inlineCall(f *funcScope, n *ast.CallExpr) { pkg := c.buildInfo.program.Package(f.pkg.Path()) sig := c.typeOf(n.Fun).(*types.Signature) + // When inlined call is used during global initialization + // there is no func scope, thus this if. + if c.scope == nil { + c.scope = &funcScope{} + c.scope.vars.newScope() + defer func() { c.scope = nil }() + } + // Arguments need to be walked with the current scope, // while stored in the new. oldScope := c.scope.vars.locals diff --git a/pkg/compiler/inline_test.go b/pkg/compiler/inline_test.go index 9035a8d2a..9db10359c 100644 --- a/pkg/compiler/inline_test.go +++ b/pkg/compiler/inline_test.go @@ -115,6 +115,32 @@ func TestInline(t *testing.T) { }) } +func TestInlineGlobalVariable(t *testing.T) { + t.Run("simple", func(t *testing.T) { + src := `package foo + import "github.com/nspcc-dev/neo-go/pkg/compiler/testdata/inline" + var a = inline.Sum(1, 2) + func Main() int { + return a + }` + eval(t, src, big.NewInt(3)) + }) + t.Run("complex", func(t *testing.T) { + src := `package foo + import "github.com/nspcc-dev/neo-go/pkg/compiler/testdata/inline" + var a = inline.Sum(3, 4) + var b = inline.SumSquared(1, 2) + var c = a + b + func init() { + c-- + } + func Main() int { + return c + }` + eval(t, src, big.NewInt(15)) + }) +} + func TestInlineConversion(t *testing.T) { src1 := `package foo import "github.com/nspcc-dev/neo-go/pkg/compiler/testdata/inline" diff --git a/pkg/compiler/syscall_test.go b/pkg/compiler/syscall_test.go index 9db1af6c3..f5dbd8e35 100644 --- a/pkg/compiler/syscall_test.go +++ b/pkg/compiler/syscall_test.go @@ -4,10 +4,12 @@ import ( "math/big" "testing" + "github.com/nspcc-dev/neo-go/pkg/core/interop/interopnames" istorage "github.com/nspcc-dev/neo-go/pkg/core/interop/storage" "github.com/nspcc-dev/neo-go/pkg/interop/contract" "github.com/nspcc-dev/neo-go/pkg/interop/storage" "github.com/nspcc-dev/neo-go/pkg/smartcontract/callflag" + "github.com/nspcc-dev/neo-go/pkg/vm" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -72,3 +74,21 @@ func TestNotify(t *testing.T) { assert.Equal(t, "single", s.events[1].Name) assert.Equal(t, []stackitem.Item{}, s.events[1].Item.Value()) } + +func TestSyscallInGlobalInit(t *testing.T) { + src := `package foo + import "github.com/nspcc-dev/neo-go/pkg/interop/binary" + var a = binary.Base58Decode([]byte("5T")) + func Main() []byte { + return a + }` + v, s := vmAndCompileInterop(t, src) + s.interops[interopnames.ToID([]byte(interopnames.SystemBinaryBase58Decode))] = func(v *vm.VM) error { + s := v.Estack().Pop().Value().([]byte) + require.Equal(t, "5T", string(s)) + v.Estack().PushVal([]byte{1, 2}) + return nil + } + require.NoError(t, v.Run()) + require.Equal(t, []byte{1, 2}, v.Estack().Pop().Value()) +}