From 94f6a9ee61f12f2118f28e8a3ed02db252e0341e Mon Sep 17 00:00:00 2001 From: Anna Shaleva Date: Thu, 14 Jul 2022 15:34:57 +0300 Subject: [PATCH 1/2] compiler: disallow unnamed parameters for exported methods --- pkg/compiler/analysis.go | 23 ++++++++++++ pkg/compiler/codegen.go | 3 ++ pkg/compiler/compiler_test.go | 70 +++++++++++++++++++++++++++++++++++ 3 files changed, 96 insertions(+) diff --git a/pkg/compiler/analysis.go b/pkg/compiler/analysis.go index e0eaee510..317838ece 100644 --- a/pkg/compiler/analysis.go +++ b/pkg/compiler/analysis.go @@ -2,6 +2,7 @@ package compiler import ( "errors" + "fmt" "go/ast" "go/token" "go/types" @@ -12,6 +13,9 @@ import ( "golang.org/x/tools/go/packages" ) +// ErrMissingExportedParamName is returned when exported contract method has unnamed parameter. +var ErrMissingExportedParamName = errors.New("exported method is not allowed to have unnamed parameter") + var ( // Go language builtin functions. goBuiltins = []string{"len", "append", "panic", "make", "copy", "recover", "delete"} @@ -284,12 +288,31 @@ func (c *codegen) analyzeFuncUsage() funcUsage { if isMain && n.Name.IsExported() || isInitFunc(n) || isDeployFunc(n) { diff[name] = true } + if isMain && n.Name.IsExported() { + if n.Type.Params.List != nil { + for i, param := range n.Type.Params.List { + if param.Names == nil { + c.prog.Err = fmt.Errorf("%w: %s", ErrMissingExportedParamName, n.Name) + return false // Program is invalid. + } + for _, name := range param.Names { + if name == nil || name.Name == "_" { + c.prog.Err = fmt.Errorf("%w: %s/%d", ErrMissingExportedParamName, n.Name, i) + return false // Program is invalid. + } + } + } + } + } nodeCache[name] = declPair{n, c.importMap, pkgPath} return false // will be processed in the next stage } return true }) }) + if c.prog.Err != nil { + return nil + } usage := funcUsage{} for len(diff) != 0 { diff --git a/pkg/compiler/codegen.go b/pkg/compiler/codegen.go index 33cec05f3..a554dae80 100644 --- a/pkg/compiler/codegen.go +++ b/pkg/compiler/codegen.go @@ -2106,6 +2106,9 @@ func (c *codegen) compile(info *buildInfo, pkg *packages.Package) error { c.analyzePkgOrder() c.fillDocumentInfo() funUsage := c.analyzeFuncUsage() + if c.prog.Err != nil { + return c.prog.Err + } // Bring all imported functions into scope. c.ForEachFile(c.resolveFuncDecls) diff --git a/pkg/compiler/compiler_test.go b/pkg/compiler/compiler_test.go index 7dcd59786..2e04251e4 100644 --- a/pkg/compiler/compiler_test.go +++ b/pkg/compiler/compiler_test.go @@ -343,3 +343,73 @@ func TestInvokedContractsPermissons(t *testing.T) { }) }) } + +func TestUnnamedParameterCheck(t *testing.T) { + t.Run("single argument", func(t *testing.T) { + src := ` + package testcase + func Main(_ int) int { + x := 10 + return x + } + ` + _, _, err := compiler.CompileWithOptions("test.go", strings.NewReader(src), nil) + require.Error(t, err) + require.ErrorIs(t, err, compiler.ErrMissingExportedParamName) + }) + t.Run("several arguments", func(t *testing.T) { + src := ` + package testcase + func Main(a int, b string, _ int) int { + x := 10 + return x + } + ` + _, _, err := compiler.CompileWithOptions("test.go", strings.NewReader(src), nil) + require.Error(t, err) + require.ErrorIs(t, err, compiler.ErrMissingExportedParamName) + }) + t.Run("interface", func(t *testing.T) { + src := ` + package testcase + func OnNEP17Payment(h string, i int, _ interface{}){} + ` + _, _, err := compiler.CompileWithOptions("test.go", strings.NewReader(src), nil) + require.Error(t, err) + require.ErrorIs(t, err, compiler.ErrMissingExportedParamName) + }) + t.Run("a set of unnamed params", func(t *testing.T) { + src := ` + package testcase + func OnNEP17Payment(_ string, _ int, _ interface{}){} + ` + _, _, err := compiler.CompileWithOptions("test.go", strings.NewReader(src), nil) + require.Error(t, err) + require.ErrorIs(t, err, compiler.ErrMissingExportedParamName) + }) + t.Run("mixed named and unnamed params", func(t *testing.T) { + src := ` + package testcase + func OnNEP17Payment(s0, _, s2 string){} + ` + _, _, err := compiler.CompileWithOptions("test.go", strings.NewReader(src), nil) + require.Error(t, err) + require.ErrorIs(t, err, compiler.ErrMissingExportedParamName) + }) + t.Run("empty args", func(t *testing.T) { + src := ` + package testcase + func OnNEP17Payment(){} + ` + _, _, err := compiler.CompileWithOptions("test.go", strings.NewReader(src), nil) + require.NoError(t, err) + }) + t.Run("good", func(t *testing.T) { + src := ` + package testcase + func OnNEP17Payment(s string, i int, iface interface{}){} + ` + _, _, err := compiler.CompileWithOptions("test.go", strings.NewReader(src), nil) + require.NoError(t, err) + }) +} From 725e8779a138ce2acd2bf6d74b6aa7f23d7853f7 Mon Sep 17 00:00:00 2001 From: Anna Shaleva Date: Thu, 14 Jul 2022 15:36:21 +0300 Subject: [PATCH 2/2] compiler: always ensure manifest passes base check --- pkg/compiler/compiler.go | 4 ++++ pkg/compiler/compiler_test.go | 19 ++++++++++++------- pkg/smartcontract/manifest/group.go | 11 +++++++---- pkg/smartcontract/manifest/group_test.go | 3 +++ pkg/smartcontract/manifest/manifest.go | 1 + 5 files changed, 27 insertions(+), 11 deletions(-) diff --git a/pkg/compiler/compiler.go b/pkg/compiler/compiler.go index 949230e44..2bfa66d0d 100644 --- a/pkg/compiler/compiler.go +++ b/pkg/compiler/compiler.go @@ -331,6 +331,10 @@ func CreateManifest(di *DebugInfo, o *Options) (*manifest.Manifest, error) { return m, fmt.Errorf("method %s is marked as safe but missing from manifest", name) } } + err = m.IsValid(util.Uint160{}) // Check as much as possible without hash. + if err != nil { + return m, fmt.Errorf("manifest is invalid: %w", err) + } if !o.NoStandardCheck { if err := standard.CheckABI(m, o.ContractSupportedStandards...); err != nil { return m, err diff --git a/pkg/compiler/compiler_test.go b/pkg/compiler/compiler_test.go index 2e04251e4..c9e326865 100644 --- a/pkg/compiler/compiler_test.go +++ b/pkg/compiler/compiler_test.go @@ -94,7 +94,7 @@ func TestOnPayableChecks(t *testing.T) { compileAndCheck := func(t *testing.T, src string) error { _, di, err := compiler.CompileWithOptions("payable.go", strings.NewReader(src), nil) require.NoError(t, err) - _, err = compiler.CreateManifest(di, &compiler.Options{}) + _, err = compiler.CreateManifest(di, &compiler.Options{Name: "payable"}) return err } @@ -132,10 +132,10 @@ func TestSafeMethodWarnings(t *testing.T) { &compiler.Options{Name: "eventTest"}) require.NoError(t, err) - _, err = compiler.CreateManifest(di, &compiler.Options{SafeMethods: []string{"main"}}) + _, err = compiler.CreateManifest(di, &compiler.Options{SafeMethods: []string{"main"}, Name: "eventTest"}) require.NoError(t, err) - _, err = compiler.CreateManifest(di, &compiler.Options{SafeMethods: []string{"main", "mississippi"}}) + _, err = compiler.CreateManifest(di, &compiler.Options{SafeMethods: []string{"main", "mississippi"}, Name: "eventTest"}) require.Error(t, err) } @@ -148,17 +148,18 @@ func TestEventWarnings(t *testing.T) { require.NoError(t, err) t.Run("event it missing from config", func(t *testing.T) { - _, err = compiler.CreateManifest(di, &compiler.Options{}) + _, err = compiler.CreateManifest(di, &compiler.Options{Name: "payable"}) require.Error(t, err) t.Run("suppress", func(t *testing.T) { - _, err = compiler.CreateManifest(di, &compiler.Options{NoEventsCheck: true}) + _, err = compiler.CreateManifest(di, &compiler.Options{NoEventsCheck: true, Name: "payable"}) require.NoError(t, err) }) }) t.Run("wrong parameter number", func(t *testing.T) { _, err = compiler.CreateManifest(di, &compiler.Options{ ContractEvents: []manifest.Event{{Name: "Event"}}, + Name: "payable", }) require.Error(t, err) }) @@ -168,6 +169,7 @@ func TestEventWarnings(t *testing.T) { Name: "Event", Parameters: []manifest.Parameter{manifest.NewParameter("number", smartcontract.StringType)}, }}, + Name: "payable", }) require.Error(t, err) }) @@ -177,6 +179,7 @@ func TestEventWarnings(t *testing.T) { Name: "Event", Parameters: []manifest.Parameter{manifest.NewParameter("number", smartcontract.IntegerType)}, }}, + Name: "payable", }) require.NoError(t, err) }) @@ -191,7 +194,7 @@ func TestEventWarnings(t *testing.T) { _, di, err := compiler.CompileWithOptions("eventTest.go", strings.NewReader(src), &compiler.Options{Name: "eventTest"}) require.NoError(t, err) - _, err = compiler.CreateManifest(di, &compiler.Options{NoEventsCheck: true}) + _, err = compiler.CreateManifest(di, &compiler.Options{NoEventsCheck: true, Name: "eventTest"}) require.NoError(t, err) }) t.Run("used", func(t *testing.T) { @@ -206,11 +209,12 @@ func TestEventWarnings(t *testing.T) { strings.NewReader(src), &compiler.Options{Name: "eventTest"}) require.NoError(t, err) - _, err = compiler.CreateManifest(di, &compiler.Options{}) + _, err = compiler.CreateManifest(di, &compiler.Options{Name: "eventTest"}) require.Error(t, err) _, err = compiler.CreateManifest(di, &compiler.Options{ ContractEvents: []manifest.Event{{Name: "Event"}}, + Name: "eventTest", }) require.NoError(t, err) }) @@ -243,6 +247,7 @@ func TestInvokedContractsPermissons(t *testing.T) { o := &compiler.Options{ NoPermissionsCheck: disable, Permissions: ps, + Name: "test", } _, err := compiler.CreateManifest(di, o) diff --git a/pkg/smartcontract/manifest/group.go b/pkg/smartcontract/manifest/group.go index 11c59fdff..c66126e80 100644 --- a/pkg/smartcontract/manifest/group.go +++ b/pkg/smartcontract/manifest/group.go @@ -38,11 +38,14 @@ func (g *Group) IsValid(h util.Uint160) error { } // AreValid checks for groups correctness and uniqueness. +// If the contract hash is empty, then hash-related checks are omitted. func (g Groups) AreValid(h util.Uint160) error { - for i := range g { - err := g[i].IsValid(h) - if err != nil { - return err + if !h.Equals(util.Uint160{}) { + for i := range g { + err := g[i].IsValid(h) + if err != nil { + return err + } } } if len(g) < 2 { diff --git a/pkg/smartcontract/manifest/group_test.go b/pkg/smartcontract/manifest/group_test.go index 6c50832c8..2390da7f5 100644 --- a/pkg/smartcontract/manifest/group_test.go +++ b/pkg/smartcontract/manifest/group_test.go @@ -40,6 +40,9 @@ func TestGroupsAreValid(t *testing.T) { gps = Groups{gcorrect, gcorrect} require.Error(t, gps.AreValid(h)) + + gps = Groups{gincorrect} + require.NoError(t, gps.AreValid(util.Uint160{})) // empty hash. } func TestGroupsContains(t *testing.T) { diff --git a/pkg/smartcontract/manifest/manifest.go b/pkg/smartcontract/manifest/manifest.go index b5264f88a..e0b968772 100644 --- a/pkg/smartcontract/manifest/manifest.go +++ b/pkg/smartcontract/manifest/manifest.go @@ -82,6 +82,7 @@ func (m *Manifest) CanCall(hash util.Uint160, toCall *Manifest, method string) b // IsValid checks manifest internal consistency and correctness, one of the // checks is for group signature correctness, contract hash is passed for it. +// If hash is empty, then hash-related checks are omitted. func (m *Manifest) IsValid(hash util.Uint160) error { var err error