diff --git a/cli/smartcontract/smart_contract.go b/cli/smartcontract/smart_contract.go index 4e0728354..ab8016d01 100644 --- a/cli/smartcontract/smart_contract.go +++ b/cli/smartcontract/smart_contract.go @@ -1,7 +1,6 @@ package smartcontract import ( - "bytes" "encoding/hex" "encoding/json" "errors" @@ -563,12 +562,10 @@ func inspect(ctx *cli.Context) error { if len(in) == 0 { return cli.NewExitError(errNoInput, 1) } - b, err := ioutil.ReadFile(in) - if err != nil { - return cli.NewExitError(err, 1) - } + var b []byte + var err error if compile { - b, err = compiler.Compile(in, bytes.NewReader(b)) + b, err = compiler.Compile(in, nil) if err != nil { return cli.NewExitError(fmt.Errorf("failed to compile: %w", err), 1) } diff --git a/pkg/compiler/compiler.go b/pkg/compiler/compiler.go index 740c990bd..7c131319f 100644 --- a/pkg/compiler/compiler.go +++ b/pkg/compiler/compiler.go @@ -1,8 +1,8 @@ package compiler import ( - "bytes" "encoding/json" + "errors" "fmt" "go/ast" "go/parser" @@ -10,6 +10,7 @@ import ( "io" "io/ioutil" "os" + "path" "strings" "github.com/nspcc-dev/neo-go/pkg/smartcontract" @@ -85,11 +86,32 @@ func (c *codegen) fillImportMap(f *ast.File, pkg *types.Package) { func getBuildInfo(name string, src interface{}) (*buildInfo, error) { conf := loader.Config{ParserMode: parser.ParseComments} - f, err := conf.ParseFile(name, src) - if err != nil { - return nil, err + if src != nil { + f, err := conf.ParseFile(name, src) + if err != nil { + return nil, err + } + conf.CreateFromFiles("", f) + } else { + var names []string + if strings.HasSuffix(name, ".go") { + names = append(names, name) + } else { + ds, err := ioutil.ReadDir(name) + if err != nil { + return nil, fmt.Errorf("'%s' is neither Go source nor a directory", name) + } + for i := range ds { + if !ds[i].IsDir() && strings.HasSuffix(ds[i].Name(), ".go") { + names = append(names, path.Join(name, ds[i].Name())) + } + } + } + if len(names) == 0 { + return nil, errors.New("no files provided") + } + conf.CreateFromFilenames("", names...) } - conf.CreateFromFiles("", f) prog, err := conf.Load() if err != nil { @@ -97,12 +119,14 @@ func getBuildInfo(name string, src interface{}) (*buildInfo, error) { } return &buildInfo{ - initialPackage: f.Name.Name, + initialPackage: prog.InitialPackages()[0].Pkg.Name(), program: prog, }, nil } // Compile compiles a Go program into bytecode that can run on the NEO virtual machine. +// If `r != nil`, `name` is interpreted as a filename, and `r` as file contents. +// Otherwise `name` is either file name or name of the directory containing source files. func Compile(name string, r io.Reader) ([]byte, error) { buf, _, err := CompileWithDebugInfo(name, r) if err != nil { @@ -123,21 +147,18 @@ func CompileWithDebugInfo(name string, r io.Reader) ([]byte, *DebugInfo, error) // CompileAndSave will compile and save the file to disk in the NEF format. func CompileAndSave(src string, o *Options) ([]byte, error) { - if !strings.HasSuffix(src, ".go") { - return nil, fmt.Errorf("%s is not a Go file", src) - } o.Outfile = strings.TrimSuffix(o.Outfile, fmt.Sprintf(".%s", fileExt)) if len(o.Outfile) == 0 { - o.Outfile = strings.TrimSuffix(src, ".go") + if strings.HasSuffix(src, ".go") { + o.Outfile = strings.TrimSuffix(src, ".go") + } else { + o.Outfile = "out" + } } if len(o.Ext) == 0 { o.Ext = fileExt } - b, err := ioutil.ReadFile(src) - if err != nil { - return nil, err - } - b, di, err := CompileWithDebugInfo(src, bytes.NewReader(b)) + b, di, err := CompileWithDebugInfo(src, nil) if err != nil { return nil, fmt.Errorf("error while trying to compile smart contract file: %w", err) } diff --git a/pkg/compiler/compiler_test.go b/pkg/compiler/compiler_test.go index 2e114cbef..797e0a16a 100644 --- a/pkg/compiler/compiler_test.go +++ b/pkg/compiler/compiler_test.go @@ -24,6 +24,20 @@ func TestCompiler(t *testing.T) { // CompileAndSave use config.Version for proper .nef generation. config.Version = "0.90.0-test" testCases := []compilerTestCase{ + { + name: "TestCompileDirectory", + function: func(t *testing.T) { + const multiMainDir = "testdata/multi" + _, di, err := compiler.CompileWithDebugInfo(multiMainDir, nil) + require.NoError(t, err) + m := map[string]bool{} + for i := range di.Methods { + m[di.Methods[i].Name.Name] = true + } + require.Contains(t, m, "Func1") + require.Contains(t, m, "Func2") + }, + }, { name: "TestCompile", function: func(t *testing.T) { @@ -73,10 +87,6 @@ func filterFilename(infos []os.FileInfo) string { } func compileFile(src string) error { - file, err := os.Open(src) - if err != nil { - return err - } - _, err = compiler.Compile("foo.go", file) + _, err := compiler.Compile(src, nil) return err } diff --git a/pkg/compiler/testdata/multi/file1.go b/pkg/compiler/testdata/multi/file1.go index c51714e74..9cbe0cd1d 100644 --- a/pkg/compiler/testdata/multi/file1.go +++ b/pkg/compiler/testdata/multi/file1.go @@ -3,3 +3,7 @@ package multi var SomeVar12 = 12 const SomeConst = 42 + +func Func1() bool { + return true +} diff --git a/pkg/compiler/testdata/multi/file2.go b/pkg/compiler/testdata/multi/file2.go index 2ee034599..a96ba78d9 100644 --- a/pkg/compiler/testdata/multi/file2.go +++ b/pkg/compiler/testdata/multi/file2.go @@ -5,3 +5,7 @@ var SomeVar30 = 30 func Sum() int { return SomeVar12 + SomeVar30 } + +func Func2() bool { + return false +} diff --git a/pkg/vm/cli/cli.go b/pkg/vm/cli/cli.go index cc1011cb8..a5fc1b060 100644 --- a/pkg/vm/cli/cli.go +++ b/pkg/vm/cli/cli.go @@ -6,7 +6,6 @@ import ( "encoding/hex" "errors" "fmt" - "io/ioutil" "math/big" "os" "strconv" @@ -306,12 +305,7 @@ func handleLoadGo(c *ishell.Context) { c.Err(errors.New("missing parameter ")) return } - fb, err := ioutil.ReadFile(c.Args[0]) - if err != nil { - c.Err(err) - return - } - b, err := compiler.Compile(c.Args[0], bytes.NewReader(fb)) + b, err := compiler.Compile(c.Args[0], nil) if err != nil { c.Err(err) return