mirror of
https://github.com/nspcc-dev/neo-go.git
synced 2025-01-26 19:17:24 +00:00
Merge pull request #2648 from nspcc-dev/restrict-multi-ret
compiler: adjust restrictions imposed on manifest functions
This commit is contained in:
commit
e23fc11da5
3 changed files with 177 additions and 18 deletions
|
@ -13,8 +13,13 @@ import (
|
||||||
"golang.org/x/tools/go/packages"
|
"golang.org/x/tools/go/packages"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ErrMissingExportedParamName is returned when exported contract method has unnamed parameter.
|
// Various exported functions usage errors.
|
||||||
var ErrMissingExportedParamName = errors.New("exported method is not allowed to have unnamed parameter")
|
var (
|
||||||
|
// ErrMissingExportedParamName is returned when exported contract method has unnamed parameter.
|
||||||
|
ErrMissingExportedParamName = errors.New("exported method is not allowed to have unnamed parameter")
|
||||||
|
// ErrInvalidExportedRetCount is returned when exported contract method has invalid return values count.
|
||||||
|
ErrInvalidExportedRetCount = errors.New("exported method is not allowed to have more than one return value")
|
||||||
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
// Go language builtin functions.
|
// Go language builtin functions.
|
||||||
|
@ -285,11 +290,12 @@ func (c *codegen) analyzeFuncUsage() funcUsage {
|
||||||
case *ast.FuncDecl:
|
case *ast.FuncDecl:
|
||||||
name := c.getFuncNameFromDecl(pkgPath, n)
|
name := c.getFuncNameFromDecl(pkgPath, n)
|
||||||
|
|
||||||
// exported functions are always assumed to be used
|
// exported functions and methods are always assumed to be used
|
||||||
if isMain && n.Name.IsExported() || isInitFunc(n) || isDeployFunc(n) {
|
if isMain && n.Name.IsExported() || isInitFunc(n) || isDeployFunc(n) {
|
||||||
diff[name] = true
|
diff[name] = true
|
||||||
}
|
}
|
||||||
if isMain && n.Name.IsExported() {
|
// exported functions are not allowed to have unnamed parameters or multiple return values
|
||||||
|
if isMain && n.Name.IsExported() && n.Recv == nil {
|
||||||
if n.Type.Params.List != nil {
|
if n.Type.Params.List != nil {
|
||||||
for i, param := range n.Type.Params.List {
|
for i, param := range n.Type.Params.List {
|
||||||
if param.Names == nil {
|
if param.Names == nil {
|
||||||
|
@ -304,6 +310,9 @@ func (c *codegen) analyzeFuncUsage() funcUsage {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if retCnt := n.Type.Results.NumFields(); retCnt > 1 {
|
||||||
|
c.prog.Err = fmt.Errorf("%w: %s/%d return values", ErrInvalidExportedRetCount, n.Name, retCnt)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
nodeCache[name] = declPair{n, c.importMap, pkgPath}
|
nodeCache[name] = declPair{n, c.importMap, pkgPath}
|
||||||
return false // will be processed in the next stage
|
return false // will be processed in the next stage
|
||||||
|
|
|
@ -2,6 +2,7 @@ package compiler_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"math/big"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
|
@ -417,4 +418,157 @@ func TestUnnamedParameterCheck(t *testing.T) {
|
||||||
_, _, err := compiler.CompileWithOptions("test.go", strings.NewReader(src), nil)
|
_, _, err := compiler.CompileWithOptions("test.go", strings.NewReader(src), nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
})
|
})
|
||||||
|
t.Run("method with unnamed params", func(t *testing.T) {
|
||||||
|
src := `
|
||||||
|
package testcase
|
||||||
|
type A int
|
||||||
|
func (rsv A) OnNEP17Payment(_ string, _ int, iface interface{}){}
|
||||||
|
`
|
||||||
|
_, _, err := compiler.CompileWithOptions("test.go", strings.NewReader(src), nil)
|
||||||
|
require.NoError(t, err) // it's OK for exported method to have unnamed params as it won't be included into manifest
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReturnValuesCountCheck(t *testing.T) {
|
||||||
|
t.Run("void", func(t *testing.T) {
|
||||||
|
t.Run("exported", func(t *testing.T) {
|
||||||
|
t.Run("func", func(t *testing.T) {
|
||||||
|
src := `package testcase
|
||||||
|
var a int
|
||||||
|
func Main() {
|
||||||
|
a = 5
|
||||||
|
}`
|
||||||
|
_, _, err := compiler.CompileWithOptions("test.go", strings.NewReader(src), nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
})
|
||||||
|
t.Run("method", func(t *testing.T) {
|
||||||
|
src := `package testcase
|
||||||
|
type A int
|
||||||
|
var a int
|
||||||
|
func (rcv A) Main() {
|
||||||
|
a = 5
|
||||||
|
}`
|
||||||
|
_, _, err := compiler.CompileWithOptions("test.go", strings.NewReader(src), nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
t.Run("unexported", func(t *testing.T) {
|
||||||
|
src := `package testcase
|
||||||
|
var a int
|
||||||
|
func main() {
|
||||||
|
a = 5
|
||||||
|
}`
|
||||||
|
_, _, err := compiler.CompileWithOptions("test.go", strings.NewReader(src), nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
t.Run("single return", func(t *testing.T) {
|
||||||
|
t.Run("exported", func(t *testing.T) {
|
||||||
|
t.Run("func", func(t *testing.T) {
|
||||||
|
src := `package testcase
|
||||||
|
var a int
|
||||||
|
func Main() int {
|
||||||
|
a = 5
|
||||||
|
return a
|
||||||
|
}`
|
||||||
|
eval(t, src, big.NewInt(5))
|
||||||
|
})
|
||||||
|
t.Run("method", func(t *testing.T) {
|
||||||
|
src := `package testcase
|
||||||
|
type A int
|
||||||
|
var a int
|
||||||
|
func (rcv A) Main() int {
|
||||||
|
a = 5
|
||||||
|
return a
|
||||||
|
}`
|
||||||
|
_, _, err := compiler.CompileWithOptions("test.go", strings.NewReader(src), nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
t.Run("unexported", func(t *testing.T) {
|
||||||
|
src := `package testcase
|
||||||
|
var a int
|
||||||
|
func main() int {
|
||||||
|
a = 5
|
||||||
|
return a
|
||||||
|
}`
|
||||||
|
_, _, err := compiler.CompileWithOptions("test.go", strings.NewReader(src), nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
t.Run("multiple unnamed return vals", func(t *testing.T) {
|
||||||
|
t.Run("exported", func(t *testing.T) {
|
||||||
|
t.Run("func", func(t *testing.T) {
|
||||||
|
src := `package testcase
|
||||||
|
var a int
|
||||||
|
func Main() (int, int) {
|
||||||
|
a = 5
|
||||||
|
return a, a
|
||||||
|
}`
|
||||||
|
_, _, err := compiler.CompileWithOptions("test.go", strings.NewReader(src), nil)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.ErrorIs(t, err, compiler.ErrInvalidExportedRetCount)
|
||||||
|
})
|
||||||
|
t.Run("method", func(t *testing.T) {
|
||||||
|
src := `package testcase
|
||||||
|
type A int
|
||||||
|
var a int
|
||||||
|
func (rcv A) Main() (int, int) {
|
||||||
|
a = 5
|
||||||
|
return a, a
|
||||||
|
}`
|
||||||
|
_, _, err := compiler.CompileWithOptions("test.go", strings.NewReader(src), nil)
|
||||||
|
require.NoError(t, err) // OK for method to have multiple return values as it won't be included into manifest
|
||||||
|
})
|
||||||
|
})
|
||||||
|
t.Run("unexported", func(t *testing.T) {
|
||||||
|
src := `package testcase
|
||||||
|
var a int
|
||||||
|
func main() (int, int) {
|
||||||
|
a = 5
|
||||||
|
return a, a
|
||||||
|
}`
|
||||||
|
_, _, err := compiler.CompileWithOptions("test.go", strings.NewReader(src), nil)
|
||||||
|
require.NoError(t, err) // OK for unexported function to have multiple return values as it won't be included into manifest
|
||||||
|
})
|
||||||
|
})
|
||||||
|
t.Run("multiple named return vals", func(t *testing.T) {
|
||||||
|
t.Run("exported", func(t *testing.T) {
|
||||||
|
t.Run("func", func(t *testing.T) {
|
||||||
|
src := `package testcase
|
||||||
|
var a int
|
||||||
|
func Main() (a int, b int) {
|
||||||
|
a = 5
|
||||||
|
b = 2
|
||||||
|
return
|
||||||
|
}`
|
||||||
|
_, _, err := compiler.CompileWithOptions("test.go", strings.NewReader(src), nil)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.ErrorIs(t, err, compiler.ErrInvalidExportedRetCount)
|
||||||
|
})
|
||||||
|
t.Run("method", func(t *testing.T) {
|
||||||
|
src := `package testcase
|
||||||
|
type A int
|
||||||
|
var a int
|
||||||
|
func (rcv A) Main() (a int, b int) {
|
||||||
|
a = 5
|
||||||
|
b = 2
|
||||||
|
return
|
||||||
|
}`
|
||||||
|
_, _, err := compiler.CompileWithOptions("test.go", strings.NewReader(src), nil)
|
||||||
|
require.NoError(t, err) // OK for method to have multiple return values as it won't be included into manifest
|
||||||
|
})
|
||||||
|
})
|
||||||
|
t.Run("unexported", func(t *testing.T) {
|
||||||
|
src := `package testcase
|
||||||
|
var a int
|
||||||
|
func main() (a int, b int) {
|
||||||
|
a = 5
|
||||||
|
b = 2
|
||||||
|
return
|
||||||
|
}`
|
||||||
|
_, _, err := compiler.CompileWithOptions("test.go", strings.NewReader(src), nil)
|
||||||
|
require.NoError(t, err) // OK for unexported function to have multiple return values as it won't be included into manifest
|
||||||
|
})
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,9 +4,6 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"math/big"
|
"math/big"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestReturnInt64(t *testing.T) {
|
func TestReturnInt64(t *testing.T) {
|
||||||
|
@ -99,7 +96,11 @@ func TestSingleReturn(t *testing.T) {
|
||||||
|
|
||||||
func TestNamedReturn(t *testing.T) {
|
func TestNamedReturn(t *testing.T) {
|
||||||
src := `package foo
|
src := `package foo
|
||||||
func Main() (a int, b int) {
|
func Main() int {
|
||||||
|
a, b := f()
|
||||||
|
return a + b
|
||||||
|
}
|
||||||
|
func f() (a int, b int) {
|
||||||
a = 1
|
a = 1
|
||||||
b = 2
|
b = 2
|
||||||
c := 3
|
c := 3
|
||||||
|
@ -107,21 +108,16 @@ func TestNamedReturn(t *testing.T) {
|
||||||
return %s
|
return %s
|
||||||
}`
|
}`
|
||||||
|
|
||||||
runCase := func(ret string, result ...interface{}) func(t *testing.T) {
|
runCase := func(ret string, result *big.Int) func(t *testing.T) {
|
||||||
return func(t *testing.T) {
|
return func(t *testing.T) {
|
||||||
src := fmt.Sprintf(src, ret)
|
src := fmt.Sprintf(src, ret)
|
||||||
v := vmAndCompile(t, src)
|
eval(t, src, result)
|
||||||
require.NoError(t, v.Run())
|
|
||||||
require.Equal(t, len(result), v.Estack().Len())
|
|
||||||
for i := range result {
|
|
||||||
assert.EqualValues(t, result[i], v.Estack().Pop().Value())
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Run("NormalReturn", runCase("a, b", big.NewInt(1), big.NewInt(2)))
|
t.Run("NormalReturn", runCase("a, b", big.NewInt(3)))
|
||||||
t.Run("EmptyReturn", runCase("", big.NewInt(1), big.NewInt(2)))
|
t.Run("EmptyReturn", runCase("", big.NewInt(3)))
|
||||||
t.Run("AnotherVariable", runCase("b, c", big.NewInt(2), big.NewInt(3)))
|
t.Run("AnotherVariable", runCase("b, c", big.NewInt(5)))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTypeAssertReturn(t *testing.T) {
|
func TestTypeAssertReturn(t *testing.T) {
|
||||||
|
|
Loading…
Reference in a new issue