diff --git a/cli/smartcontract/generate.go b/cli/smartcontract/generate.go new file mode 100644 index 000000000..c3e627866 --- /dev/null +++ b/cli/smartcontract/generate.go @@ -0,0 +1,80 @@ +package smartcontract + +import ( + "fmt" + "io/ioutil" + "os" + + "github.com/nspcc-dev/neo-go/pkg/smartcontract/binding" + "github.com/nspcc-dev/neo-go/pkg/util" + "github.com/urfave/cli" + "gopkg.in/yaml.v2" +) + +var generateWrapperCmd = cli.Command{ + Name: "generate-wrapper", + Usage: "generate wrapper to use in other contracts", + UsageText: "neo-go contract generate-wrapper --manifest manifest.json --out file.go", + Description: ``, + Action: contractGenerateWrapper, + Flags: []cli.Flag{ + cli.StringFlag{ + Name: "config, c", + Usage: "Configuration file to use", + }, + cli.StringFlag{ + Name: "manifest, m", + Usage: "Read contract manifest (*.manifest.json) file", + }, + cli.StringFlag{ + Name: "out, o", + Usage: "Output of the compiled contract", + }, + cli.StringFlag{ + Name: "hash", + Usage: "Smart-contract hash", + }, + }, +} + +// contractGenerateWrapper deploys contract. +func contractGenerateWrapper(ctx *cli.Context) error { + m, _, err := readManifest(ctx.String("manifest")) + if err != nil { + return cli.NewExitError(fmt.Errorf("can't read contract manifest: %w", err), 1) + } + + cfg := binding.NewConfig() + if cfgPath := ctx.String("config"); cfgPath != "" { + bs, err := ioutil.ReadFile(cfgPath) + if err != nil { + return cli.NewExitError(fmt.Errorf("can't read config file: %w", err), 1) + } + err = yaml.Unmarshal(bs, &cfg) + if err != nil { + return cli.NewExitError(fmt.Errorf("can't parse config file: %w", err), 1) + } + } + + cfg.Manifest = m + + h, err := util.Uint160DecodeStringLE(ctx.String("hash")) + if err != nil { + return cli.NewExitError(fmt.Errorf("invalid contract hash: %w", err), 1) + } + cfg.Hash = h + + f, err := os.Create(ctx.String("out")) + if err != nil { + return cli.NewExitError(fmt.Errorf("can't create output file: %w", err), 1) + } + defer f.Close() + + cfg.Output = f + + err = binding.Generate(cfg) + if err != nil { + return cli.NewExitError(fmt.Errorf("error during generation: %w", err), 1) + } + return nil +} diff --git a/cli/smartcontract/generate_test.go b/cli/smartcontract/generate_test.go new file mode 100644 index 000000000..83d455209 --- /dev/null +++ b/cli/smartcontract/generate_test.go @@ -0,0 +1,330 @@ +package smartcontract + +import ( + "encoding/json" + "io/ioutil" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/nspcc-dev/neo-go/pkg/smartcontract" + "github.com/nspcc-dev/neo-go/pkg/smartcontract/manifest" + "github.com/nspcc-dev/neo-go/pkg/util" + "github.com/stretchr/testify/require" + "github.com/urfave/cli" +) + +func TestGenerate(t *testing.T) { + m := manifest.NewManifest("MyContract") + m.ABI.Methods = append(m.ABI.Methods, + manifest.Method{ + Name: manifest.MethodDeploy, + ReturnType: smartcontract.VoidType, + }, + manifest.Method{ + Name: "sum", + Parameters: []manifest.Parameter{ + manifest.NewParameter("first", smartcontract.IntegerType), + manifest.NewParameter("second", smartcontract.IntegerType), + }, + ReturnType: smartcontract.IntegerType, + }, + manifest.Method{ + Name: "sum", // overloaded method + Parameters: []manifest.Parameter{ + manifest.NewParameter("first", smartcontract.IntegerType), + manifest.NewParameter("second", smartcontract.IntegerType), + manifest.NewParameter("third", smartcontract.IntegerType), + }, + ReturnType: smartcontract.IntegerType, + }, + manifest.Method{ + Name: "sum3", + Parameters: []manifest.Parameter{}, + ReturnType: smartcontract.IntegerType, + Safe: true, + }, + manifest.Method{ + Name: "justExecute", + Parameters: []manifest.Parameter{ + manifest.NewParameter("arr", smartcontract.ArrayType), + }, + ReturnType: smartcontract.VoidType, + }, + manifest.Method{ + Name: "getPublicKey", + Parameters: nil, + ReturnType: smartcontract.PublicKeyType, + }, + manifest.Method{ + Name: "otherTypes", + Parameters: []manifest.Parameter{ + manifest.NewParameter("ctr", smartcontract.Hash160Type), + manifest.NewParameter("tx", smartcontract.Hash256Type), + manifest.NewParameter("sig", smartcontract.SignatureType), + manifest.NewParameter("data", smartcontract.AnyType), + }, + ReturnType: smartcontract.BoolType, + }, + manifest.Method{ + Name: "emptyName", + Parameters: []manifest.Parameter{ + manifest.NewParameter("", smartcontract.MapType), + }, + ReturnType: smartcontract.AnyType, + }, + manifest.Method{ + Name: "searchStorage", + Parameters: []manifest.Parameter{ + manifest.NewParameter("ctx", smartcontract.InteropInterfaceType), + }, + ReturnType: smartcontract.InteropInterfaceType, + }, + manifest.Method{ + Name: "getFromMap", + Parameters: []manifest.Parameter{ + manifest.NewParameter("intMap", smartcontract.MapType), + manifest.NewParameter("indices", smartcontract.ArrayType), + }, + ReturnType: smartcontract.ArrayType, + }, + manifest.Method{ + Name: "doSomething", + Parameters: []manifest.Parameter{ + manifest.NewParameter("bytes", smartcontract.ByteArrayType), + manifest.NewParameter("str", smartcontract.StringType), + }, + ReturnType: smartcontract.InteropInterfaceType, + }, + manifest.Method{ + Name: "getBlockWrapper", + Parameters: []manifest.Parameter{}, + ReturnType: smartcontract.InteropInterfaceType, + }, + manifest.Method{ + Name: "myFunc", + Parameters: []manifest.Parameter{ + manifest.NewParameter("in", smartcontract.MapType), + }, + ReturnType: smartcontract.ArrayType, + }) + + manifestFile := filepath.Join(t.TempDir(), "manifest.json") + outFile := filepath.Join(t.TempDir(), "out.go") + + rawManifest, err := json.Marshal(m) + require.NoError(t, err) + require.NoError(t, ioutil.WriteFile(manifestFile, rawManifest, os.ModePerm)) + + h := util.Uint160{ + 0x04, 0x08, 0x15, 0x16, 0x23, 0x42, 0x43, 0x44, 0x00, 0x01, + 0xCA, 0xFE, 0xBA, 0xBE, 0xDE, 0xAD, 0xBE, 0xEF, 0x03, 0x04, + } + app := cli.NewApp() + app.Commands = []cli.Command{generateWrapperCmd} + + rawCfg := `package: wrapper +hash: ` + h.StringLE() + ` +overrides: + searchStorage.ctx: storage.Context + searchStorage: iterator.Iterator + getFromMap.intMap: "map[string]int" + getFromMap.indices: "[]string" + getFromMap: "[]int" + getBlockWrapper: ledger.Block + myFunc.in: "map[int]github.com/heyitsme/mycontract.Input" + myFunc: "[]github.com/heyitsme/mycontract.Output" +callflags: + doSomething: ReadStates +` + cfgPath := filepath.Join(t.TempDir(), "binding.yml") + require.NoError(t, ioutil.WriteFile(cfgPath, []byte(rawCfg), os.ModePerm)) + + require.NoError(t, app.Run([]string{"", "generate-wrapper", + "--manifest", manifestFile, + "--config", cfgPath, + "--out", outFile, + "--hash", h.StringLE(), + })) + + const expected = `// Package wrapper contains wrappers for MyContract contract. +package wrapper + +import ( + "github.com/heyitsme/mycontract" + "github.com/nspcc-dev/neo-go/pkg/interop" + "github.com/nspcc-dev/neo-go/pkg/interop/contract" + "github.com/nspcc-dev/neo-go/pkg/interop/iterator" + "github.com/nspcc-dev/neo-go/pkg/interop/native/ledger" + "github.com/nspcc-dev/neo-go/pkg/interop/neogointernal" + "github.com/nspcc-dev/neo-go/pkg/interop/storage" +) + +// Hash contains contract hash in big-endian form. +const Hash = "\x04\x08\x15\x16\x23\x42\x43\x44\x00\x01\xca\xfe\xba\xbe\xde\xad\xbe\xef\x03\x04" + +// Sum invokes ` + "`sum`" + ` method of contract. +func Sum(first int, second int) int { + return neogointernal.CallWithToken(Hash, "sum", int(contract.All), first, second).(int) +} + +// Sum_3 invokes ` + "`sum`" + ` method of contract. +func Sum_3(first int, second int, third int) int { + return neogointernal.CallWithToken(Hash, "sum", int(contract.All), first, second, third).(int) +} + +// Sum3 invokes ` + "`sum3`" + ` method of contract. +func Sum3() int { + return neogointernal.CallWithToken(Hash, "sum3", int(contract.ReadOnly)).(int) +} + +// JustExecute invokes ` + "`justExecute`" + ` method of contract. +func JustExecute(arr []interface{}) { + neogointernal.CallWithTokenNoRet(Hash, "justExecute", int(contract.All), arr) +} + +// GetPublicKey invokes ` + "`getPublicKey`" + ` method of contract. +func GetPublicKey() interop.PublicKey { + return neogointernal.CallWithToken(Hash, "getPublicKey", int(contract.All)).(interop.PublicKey) +} + +// OtherTypes invokes ` + "`otherTypes`" + ` method of contract. +func OtherTypes(ctr interop.Hash160, tx interop.Hash256, sig interop.Signature, data interface{}) bool { + return neogointernal.CallWithToken(Hash, "otherTypes", int(contract.All), ctr, tx, sig, data).(bool) +} + +// EmptyName invokes ` + "`emptyName`" + ` method of contract. +func EmptyName(arg0 map[string]interface{}) interface{} { + return neogointernal.CallWithToken(Hash, "emptyName", int(contract.All), arg0).(interface{}) +} + +// SearchStorage invokes ` + "`searchStorage`" + ` method of contract. +func SearchStorage(ctx storage.Context) iterator.Iterator { + return neogointernal.CallWithToken(Hash, "searchStorage", int(contract.All), ctx).(iterator.Iterator) +} + +// GetFromMap invokes ` + "`getFromMap`" + ` method of contract. +func GetFromMap(intMap map[string]int, indices []string) []int { + return neogointernal.CallWithToken(Hash, "getFromMap", int(contract.All), intMap, indices).([]int) +} + +// DoSomething invokes ` + "`doSomething`" + ` method of contract. +func DoSomething(bytes []byte, str string) interface{} { + return neogointernal.CallWithToken(Hash, "doSomething", int(contract.ReadStates), bytes, str).(interface{}) +} + +// GetBlockWrapper invokes ` + "`getBlockWrapper`" + ` method of contract. +func GetBlockWrapper() ledger.Block { + return neogointernal.CallWithToken(Hash, "getBlockWrapper", int(contract.All)).(ledger.Block) +} + +// MyFunc invokes ` + "`myFunc`" + ` method of contract. +func MyFunc(in map[int]mycontract.Input) []mycontract.Output { + return neogointernal.CallWithToken(Hash, "myFunc", int(contract.All), in).([]mycontract.Output) +} +` + + data, err := ioutil.ReadFile(outFile) + require.NoError(t, err) + require.Equal(t, expected, string(data)) +} + +func TestGenerateValidPackageName(t *testing.T) { + m := manifest.NewManifest("My space\tcontract") + m.ABI.Methods = append(m.ABI.Methods, + manifest.Method{ + Name: "get", + Parameters: []manifest.Parameter{}, + ReturnType: smartcontract.IntegerType, + }, + ) + + manifestFile := filepath.Join(t.TempDir(), "manifest.json") + outFile := filepath.Join(t.TempDir(), "out.go") + + rawManifest, err := json.Marshal(m) + require.NoError(t, err) + require.NoError(t, ioutil.WriteFile(manifestFile, rawManifest, os.ModePerm)) + + h := util.Uint160{ + 0x04, 0x08, 0x15, 0x16, 0x23, 0x42, 0x43, 0x44, 0x00, 0x01, + 0xCA, 0xFE, 0xBA, 0xBE, 0xDE, 0xAD, 0xBE, 0xEF, 0x03, 0x04, + } + app := cli.NewApp() + app.Commands = []cli.Command{generateWrapperCmd} + require.NoError(t, app.Run([]string{"", "generate-wrapper", + "--manifest", manifestFile, + "--out", outFile, + "--hash", h.StringLE(), + })) + + data, err := ioutil.ReadFile(outFile) + require.NoError(t, err) + require.Equal(t, `// Package myspacecontract contains wrappers for My space contract contract. +package myspacecontract + +import ( + "github.com/nspcc-dev/neo-go/pkg/interop/contract" + "github.com/nspcc-dev/neo-go/pkg/interop/neogointernal" +) + +// Hash contains contract hash in big-endian form. +const Hash = "\x04\x08\x15\x16\x23\x42\x43\x44\x00\x01\xca\xfe\xba\xbe\xde\xad\xbe\xef\x03\x04" + +// Get invokes `+"`get`"+` method of contract. +func Get() int { + return neogointernal.CallWithToken(Hash, "get", int(contract.All)).(int) +} +`, string(data)) +} + +func TestGenerate_Errors(t *testing.T) { + app := cli.NewApp() + app.Commands = []cli.Command{generateWrapperCmd} + app.ExitErrHandler = func(*cli.Context, error) {} + + checkError := func(t *testing.T, msg string, args ...string) { + // cli.ExitError doesn't implement wraping properly, so we check for an error message. + err := app.Run(append([]string{"", "generate-wrapper"}, args...)) + require.True(t, strings.Contains(err.Error(), msg), "got: %v", err) + } + t.Run("missing manifest argument", func(t *testing.T) { + checkError(t, errNoManifestFile.Error()) + }) + t.Run("missing manifest file", func(t *testing.T) { + checkError(t, "can't read contract manifest", "--manifest", "notexists") + }) + t.Run("invalid manifest", func(t *testing.T) { + manifestFile := filepath.Join(t.TempDir(), "invalid.json") + require.NoError(t, ioutil.WriteFile(manifestFile, []byte("[]"), os.ModePerm)) + checkError(t, "", "--manifest", manifestFile) + }) + + manifestFile := filepath.Join(t.TempDir(), "manifest.json") + m := manifest.NewManifest("MyContract") + rawManifest, err := json.Marshal(m) + require.NoError(t, err) + require.NoError(t, ioutil.WriteFile(manifestFile, rawManifest, os.ModePerm)) + + t.Run("invalid hash", func(t *testing.T) { + checkError(t, "invalid contract hash", "--manifest", manifestFile, "--hash", "xxx") + }) + t.Run("missing config", func(t *testing.T) { + checkError(t, "can't read config file", + "--manifest", manifestFile, "--hash", util.Uint160{}.StringLE(), + "--config", filepath.Join(t.TempDir(), "not.exists.yml")) + }) + t.Run("invalid config", func(t *testing.T) { + rawCfg := `package: wrapper +callflags: + someFunc: ReadSometimes +` + cfgPath := filepath.Join(t.TempDir(), "binding.yml") + require.NoError(t, ioutil.WriteFile(cfgPath, []byte(rawCfg), os.ModePerm)) + + checkError(t, "can't parse config file", + "--manifest", manifestFile, "--hash", util.Uint160{}.StringLE(), + "--config", cfgPath) + }) +} diff --git a/cli/smartcontract/smart_contract.go b/cli/smartcontract/smart_contract.go index 058e0eee5..aa8a19506 100644 --- a/cli/smartcontract/smart_contract.go +++ b/cli/smartcontract/smart_contract.go @@ -179,6 +179,7 @@ func NewCommands() []cli.Command { Action: contractDeploy, Flags: deployFlags, }, + generateWrapperCmd, { Name: "invokefunction", Usage: "invoke deployed contract on the blockchain", diff --git a/pkg/smartcontract/binding/generate.go b/pkg/smartcontract/binding/generate.go new file mode 100644 index 000000000..8791862aa --- /dev/null +++ b/pkg/smartcontract/binding/generate.go @@ -0,0 +1,254 @@ +package binding + +import ( + "bytes" + "fmt" + "io" + "sort" + "strconv" + "strings" + "text/template" + "unicode" + + "github.com/nspcc-dev/neo-go/pkg/smartcontract" + "github.com/nspcc-dev/neo-go/pkg/smartcontract/callflag" + "github.com/nspcc-dev/neo-go/pkg/smartcontract/manifest" + "github.com/nspcc-dev/neo-go/pkg/util" +) + +const srcTmpl = ` +{{- define "METHOD" -}} +// {{.Name}} {{.Comment}} +func {{.Name}}({{range $index, $arg := .Arguments -}} + {{- if ne $index 0}}, {{end}} + {{- .Name}} {{.Type}} + {{- end}}) {{if .ReturnType }}{{ .ReturnType }} { + return neogointernal.CallWithToken(Hash, "{{ .NameABI }}", int(contract.{{ .CallFlag }}) + {{- range $arg := .Arguments -}}, {{.Name}}{{end}}).({{ .ReturnType }}) + {{- else -}} { + neogointernal.CallWithTokenNoRet(Hash, "{{ .NameABI }}", int(contract.{{ .CallFlag }}) + {{- range $arg := .Arguments -}}, {{.Name}}{{end}}) + {{- end}} +} +{{- end -}} +// Package {{.PackageName}} contains wrappers for {{.ContractName}} contract. +package {{.PackageName}} + +import ( +{{range $m := .Imports}} "{{ $m }}" +{{end}}) + +// Hash contains contract hash in big-endian form. +const Hash = "{{ .Hash }}" +{{range $m := .Methods}} +{{template "METHOD" $m }} +{{end}}` + +type ( + // Config contains parameter for the generated binding. + Config struct { + Package string `yaml:"package,omitempty"` + Manifest *manifest.Manifest `yaml:"-"` + Hash util.Uint160 `yaml:"hash,omitempty"` + Overrides map[string]Override `yaml:"overrides,omitempty"` + CallFlags map[string]callflag.CallFlag `yaml:"callflags,omitempty"` + Output io.Writer `yaml:"-"` + } + + contractTmpl struct { + PackageName string + ContractName string + Imports []string + Hash string + Methods []methodTmpl + } + + methodTmpl struct { + Name string + NameABI string + CallFlag string + Comment string + Arguments []paramTmpl + ReturnType string + } + + paramTmpl struct { + Name string + Type string + } +) + +// NewConfig initializes and returns new config instance. +func NewConfig() Config { + return Config{ + Overrides: make(map[string]Override), + CallFlags: make(map[string]callflag.CallFlag), + } +} + +// Generate writes Go file containing smartcontract bindings to the `cfg.Output`. +func Generate(cfg Config) error { + ctr, err := templateFromManifest(cfg) + if err != nil { + return err + } + + tmp, err := template.New("generate").Funcs(template.FuncMap{ + "lowerFirst": lowerFirst, + "scTypeToGo": scTypeToGo, + }).Parse(srcTmpl) + if err != nil { + return err + } + + return tmp.Execute(cfg.Output, ctr) +} + +func scTypeToGo(typ smartcontract.ParamType) string { + switch typ { + case smartcontract.AnyType: + return "interface{}" + case smartcontract.BoolType: + return "bool" + case smartcontract.IntegerType: + return "int" + case smartcontract.ByteArrayType: + return "[]byte" + case smartcontract.StringType: + return "string" + case smartcontract.Hash160Type: + return "interop.Hash160" + case smartcontract.Hash256Type: + return "interop.Hash256" + case smartcontract.PublicKeyType: + return "interop.PublicKey" + case smartcontract.SignatureType: + return "interop.Signature" + case smartcontract.ArrayType: + return "[]interface{}" + case smartcontract.MapType: + return "map[string]interface{}" + case smartcontract.InteropInterfaceType: + return "interface{}" + case smartcontract.VoidType: + return "" + default: + panic("unreachable") + } +} + +func templateFromManifest(cfg Config) (contractTmpl, error) { + hStr := "" + for _, b := range cfg.Hash.BytesBE() { + hStr += fmt.Sprintf("\\x%02x", b) + } + + ctr := contractTmpl{ + PackageName: cfg.Package, + ContractName: cfg.Manifest.Name, + Hash: hStr, + } + if ctr.PackageName == "" { + buf := bytes.NewBuffer(make([]byte, 0, len(cfg.Manifest.Name))) + for _, r := range cfg.Manifest.Name { + if unicode.IsLetter(r) { + buf.WriteRune(unicode.ToLower(r)) + } + } + + ctr.PackageName = buf.String() + } + + imports := make(map[string]struct{}) + seen := make(map[string]bool) + for _, m := range cfg.Manifest.ABI.Methods { + seen[m.Name] = false + } + for _, m := range cfg.Manifest.ABI.Methods { + if m.Name[0] == '_' { + continue + } + + imports["github.com/nspcc-dev/neo-go/pkg/interop/contract"] = struct{}{} + imports["github.com/nspcc-dev/neo-go/pkg/interop/neogointernal"] = struct{}{} + + // Consider `perform(a)` and `perform(a, b)` methods. + // First, try to export the second method with `Perform2` name. + // If `perform2` is already in the manifest, use `perform_2` with as many underscores + // as needed to eliminate name conflicts. It will produce long names in certain circumstances, + // but if the manifest contains lots of similar names with trailing underscores, delicate naming + // was probably not the goal. + name := m.Name + if v, ok := seen[name]; !ok || v { + suffix := strconv.Itoa(len(m.Parameters)) + for ; seen[name]; name = m.Name + suffix { + suffix = "_" + suffix + } + } + seen[name] = true + + mtd := methodTmpl{ + Name: upperFirst(name), + NameABI: m.Name, + CallFlag: callflag.All.String(), + Comment: fmt.Sprintf("invokes `%s` method of contract.", m.Name), + } + if f, ok := cfg.CallFlags[m.Name]; ok { + mtd.CallFlag = f.String() + } else if m.Safe { + mtd.CallFlag = callflag.ReadOnly.String() + } + for i := range m.Parameters { + name := m.Parameters[i].Name + if name == "" { + name = fmt.Sprintf("arg%d", i) + } + + var typeStr string + if over, ok := cfg.Overrides[m.Name+"."+name]; ok { + typeStr = over.TypeName + if over.Package != "" { + imports[over.Package] = struct{}{} + } + } else { + typeStr = scTypeToGo(m.Parameters[i].Type) + } + + mtd.Arguments = append(mtd.Arguments, paramTmpl{ + Name: name, + Type: typeStr, + }) + } + + if over, ok := cfg.Overrides[m.Name]; ok { + mtd.ReturnType = over.TypeName + if over.Package != "" { + imports[over.Package] = struct{}{} + } + } else { + mtd.ReturnType = scTypeToGo(m.ReturnType) + switch m.ReturnType { + case smartcontract.Hash160Type, smartcontract.Hash256Type, smartcontract.InteropInterfaceType, + smartcontract.SignatureType, smartcontract.PublicKeyType: + imports["github.com/nspcc-dev/neo-go/pkg/interop"] = struct{}{} + } + } + + ctr.Methods = append(ctr.Methods, mtd) + } + + for imp := range imports { + ctr.Imports = append(ctr.Imports, imp) + } + sort.Strings(ctr.Imports) + + return ctr, nil +} + +func upperFirst(s string) string { + return strings.ToUpper(s[0:1]) + s[1:] +} + +func lowerFirst(s string) string { + return strings.ToLower(s[0:1]) + s[1:] +} diff --git a/pkg/smartcontract/binding/override.go b/pkg/smartcontract/binding/override.go new file mode 100644 index 000000000..632a0f44f --- /dev/null +++ b/pkg/smartcontract/binding/override.go @@ -0,0 +1,75 @@ +package binding + +import ( + "strings" +) + +// Override contains package and type to replace manifest method parameter type with. +type Override struct { + // Package contains fully-qualified package name. + Package string + // TypeName contains type name together with a package alias. + TypeName string +} + +// NewOverrideFromString parses s and returns method parameter type override spec. +func NewOverrideFromString(s string) Override { + var over Override + + index := strings.LastIndexByte(s, '.') + if index == -1 { + over.TypeName = s + return over + } + + // Arrays and maps can have fully-qualified types as elements. + last := strings.LastIndexAny(s, "]*") + isCompound := last != -1 && last < index + if isCompound { + over.Package = s[last+1 : index] + } else { + over.Package = s[:index] + } + + switch over.Package { + case "iterator", "storage": + over.Package = "github.com/nspcc-dev/neo-go/pkg/interop/" + over.Package + case "ledger", "management": + over.Package = "github.com/nspcc-dev/neo-go/pkg/interop/native/" + over.Package + } + + slashIndex := strings.LastIndexByte(s, '/') + if isCompound { + over.TypeName = s[:last+1] + s[slashIndex+1:] + } else { + over.TypeName = s[slashIndex+1:] + } + return over +} + +// UnmarshalYAML implements the YAML Unmarshaler interface. +func (o *Override) UnmarshalYAML(unmarshal func(interface{}) error) error { + var s string + + err := unmarshal(&s) + if err != nil { + return err + } + + *o = NewOverrideFromString(s) + return err +} + +// MarshalYAML implements the YAML marshaler interface. +func (o Override) MarshalYAML() (interface{}, error) { + if o.Package == "" { + return o.TypeName, nil + } + + index := strings.LastIndexByte(o.TypeName, '.') + last := strings.LastIndexAny(o.TypeName, "]*") + if last == -1 { + return o.Package + o.TypeName[index:], nil + } + return o.TypeName[:last+1] + o.Package + o.TypeName[index:], nil +} diff --git a/pkg/smartcontract/binding/override_test.go b/pkg/smartcontract/binding/override_test.go new file mode 100644 index 000000000..333586510 --- /dev/null +++ b/pkg/smartcontract/binding/override_test.go @@ -0,0 +1,32 @@ +package binding + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestNewOverrideFromString(t *testing.T) { + testCases := []struct { + expected Override + value string + }{ + {Override{"import.com/pkg", "pkg.Type"}, "import.com/pkg.Type"}, + {Override{"", "map[int]int"}, "map[int]int"}, + {Override{"", "[]int"}, "[]int"}, + {Override{"", "map[int][]int"}, "map[int][]int"}, + {Override{"import.com/pkg", "map[int]pkg.Type"}, "map[int]import.com/pkg.Type"}, + {Override{"import.com/pkg", "[]pkg.Type"}, "[]import.com/pkg.Type"}, + {Override{"import.com/pkg", "map[int]*pkg.Type"}, "map[int]*import.com/pkg.Type"}, + {Override{"import.com/pkg", "[]*pkg.Type"}, "[]*import.com/pkg.Type"}, + {Override{"import.com/pkg", "[][]*pkg.Type"}, "[][]*import.com/pkg.Type"}, + {Override{"import.com/pkg", "map[string][]pkg.Type"}, "map[string][]import.com/pkg.Type"}} + + for _, tc := range testCases { + require.Equal(t, tc.expected, NewOverrideFromString(tc.value)) + + s, err := tc.expected.MarshalYAML() + require.NoError(t, err) + require.Equal(t, tc.value, s) + } +}