diff --git a/pkg/compiler/codegen.go b/pkg/compiler/codegen.go index f1d84e211..ada41de1a 100644 --- a/pkg/compiler/codegen.go +++ b/pkg/compiler/codegen.go @@ -18,6 +18,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/vm" "github.com/nspcc-dev/neo-go/pkg/vm/emit" "github.com/nspcc-dev/neo-go/pkg/vm/opcode" + "golang.org/x/tools/go/loader" ) // The identifier of the entry function. Default set to Main. @@ -1240,23 +1241,12 @@ func (c *codegen) newFunc(decl *ast.FuncDecl) *funcScope { return f } -// CodeGen compiles the program to bytecode. -func CodeGen(info *buildInfo) ([]byte, error) { - pkg := info.program.Package(info.initialPackage) - c := &codegen{ - buildInfo: info, - prog: io.NewBufBinWriter(), - l: []int{}, - funcs: map[string]*funcScope{}, - labels: map[labelWithType]uint16{}, - typeInfo: &pkg.Info, - } - +func (c *codegen) compile(info *buildInfo, pkg *loader.PackageInfo) error { // Resolve the entrypoint of the program. main, mainFile := resolveEntryPoint(mainIdent, pkg) if main == nil { c.prog.Err = fmt.Errorf("could not find func main. Did you forget to declare it? ") - return []byte{}, c.prog.Err + return c.prog.Err } funUsage := analyzeFuncUsage(info.program.AllPackages) @@ -1297,9 +1287,29 @@ func CodeGen(info *buildInfo) ([]byte, error) { } } - if c.prog.Err != nil { - return nil, c.prog.Err + return c.prog.Err +} + +func newCodegen(info *buildInfo, pkg *loader.PackageInfo) *codegen { + return &codegen{ + buildInfo: info, + prog: io.NewBufBinWriter(), + l: []int{}, + funcs: map[string]*funcScope{}, + labels: map[labelWithType]uint16{}, + typeInfo: &pkg.Info, } +} + +// CodeGen compiles the program to bytecode. +func CodeGen(info *buildInfo) ([]byte, error) { + pkg := info.program.Package(info.initialPackage) + c := newCodegen(info, pkg) + + if err := c.compile(info, pkg); err != nil { + return nil, err + } + buf := c.prog.Bytes() if err := c.writeJumps(buf); err != nil { return nil, err diff --git a/pkg/compiler/compiler.go b/pkg/compiler/compiler.go index 91fc072c6..40cae382d 100644 --- a/pkg/compiler/compiler.go +++ b/pkg/compiler/compiler.go @@ -31,10 +31,9 @@ type buildInfo struct { program *loader.Program } -// Compile compiles a Go program into bytecode that can run on the NEO virtual machine. -func Compile(r io.Reader) ([]byte, error) { +func getBuildInfo(src interface{}) (*buildInfo, error) { conf := loader.Config{ParserMode: parser.ParseComments} - f, err := conf.ParseFile("", r) + f, err := conf.ParseFile("", src) if err != nil { return nil, err } @@ -45,9 +44,17 @@ func Compile(r io.Reader) ([]byte, error) { return nil, err } - ctx := &buildInfo{ + return &buildInfo{ initialPackage: f.Name.Name, program: prog, + }, nil +} + +// Compile compiles a Go program into bytecode that can run on the NEO virtual machine. +func Compile(r io.Reader) ([]byte, error) { + ctx, err := getBuildInfo(r) + if err != nil { + return nil, err } buf, err := CodeGen(ctx) diff --git a/pkg/compiler/debug_test.go b/pkg/compiler/debug_test.go index f98493077..25b2b157f 100644 --- a/pkg/compiler/debug_test.go +++ b/pkg/compiler/debug_test.go @@ -4,8 +4,66 @@ import ( "testing" "github.com/nspcc-dev/neo-go/pkg/internal/testserdes" + "github.com/nspcc-dev/neo-go/pkg/vm/opcode" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) +func TestCodeGen_DebugInfo(t *testing.T) { + src := `package foo +func Main(op string) bool { + res := methodInt(op) + _ = methodString() + _ = methodByteArray() + _ = methodArray() + _ = methodStruct() + return res == 42 +} + +func methodInt(a string) int { + if a == "get42" { + return 42 + } + return 3 +} +func methodString() string { return "" } +func methodByteArray() []byte { return nil } +func methodArray() []bool { return nil } +func methodStruct() struct{} { return struct{}{} } +` + + info, err := getBuildInfo(src) + require.NoError(t, err) + + pkg := info.program.Package(info.initialPackage) + c := newCodegen(info, pkg) + require.NoError(t, c.compile(info, pkg)) + + buf := c.prog.Bytes() + d := c.emitDebugInfo() + require.NotNil(t, d) + + t.Run("return types", func(t *testing.T) { + returnTypes := map[string]string{ + "methodInt": "Integer", + "methodString": "String", "methodByteArray": "ByteArray", + "methodArray": "Array", "methodStruct": "Struct", + "Main": "Boolean", + } + for i := range d.Methods { + name := d.Methods[i].Name.Name + assert.Equal(t, returnTypes[name], d.Methods[i].ReturnType) + } + }) + + // basic check that last instruction of every method is indeed RET + for i := range d.Methods { + index := d.Methods[i].Range.End + require.True(t, int(index) < len(buf)) + require.EqualValues(t, opcode.RET, buf[index]) + } +} + func TestDebugInfo_MarshalJSON(t *testing.T) { d := &DebugInfo{ EntryPoint: "main",