From c058ab5604392b4454f269892c96ab0e3a1e6573 Mon Sep 17 00:00:00 2001 From: Roman Khimov Date: Fri, 2 Dec 2022 10:20:55 +0300 Subject: [PATCH] rpcbinding: handle more complex non-structured types --- cli/smartcontract/testdata/types/config.yml | 2 +- .../testdata/types/rpcbindings.out | 199 ++++++++++++++++++ cli/smartcontract/testdata/types/types.go | 12 ++ pkg/smartcontract/rpcbinding/binding.go | 117 ++++++---- 4 files changed, 285 insertions(+), 45 deletions(-) diff --git a/cli/smartcontract/testdata/types/config.yml b/cli/smartcontract/testdata/types/config.yml index 52ca5bbfc..a36d2e655 100644 --- a/cli/smartcontract/testdata/types/config.yml +++ b/cli/smartcontract/testdata/types/config.yml @@ -1,3 +1,3 @@ name: "Types" sourceurl: https://github.com/nspcc-dev/neo-go/ -safemethods: ["bool", "int", "bytes", "string", "hash160", "hash256", "publicKey", "signature", "bools", "ints", "bytess", "strings", "hash160s", "hash256s", "publicKeys", "signatures"] +safemethods: ["bool", "int", "bytes", "string", "hash160", "hash256", "publicKey", "signature", "bools", "ints", "bytess", "strings", "hash160s", "hash256s", "publicKeys", "signatures", "aAAStrings", "maps", "crazyMaps"] diff --git a/cli/smartcontract/testdata/types/rpcbindings.out b/cli/smartcontract/testdata/types/rpcbindings.out index 3f7cbb48d..65df997c9 100644 --- a/cli/smartcontract/testdata/types/rpcbindings.out +++ b/cli/smartcontract/testdata/types/rpcbindings.out @@ -2,11 +2,15 @@ package types import ( + "errors" + "fmt" "github.com/nspcc-dev/neo-go/pkg/crypto/keys" "github.com/nspcc-dev/neo-go/pkg/neorpc/result" "github.com/nspcc-dev/neo-go/pkg/rpcclient/unwrap" "github.com/nspcc-dev/neo-go/pkg/util" + "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" "math/big" + "unicode/utf8" ) // Hash contains contract hash. @@ -28,6 +32,64 @@ func NewReader(invoker Invoker) *ContractReader { } +// AAAStrings invokes `aAAStrings` method of contract. +func (c *ContractReader) AAAStrings(s [][][]string) ([][][]string, error) { + return func (item stackitem.Item, err error) ([][][]string, error) { + if err != nil { + return nil, err + } + return func (item stackitem.Item) ([][][]string, error) { + arr, ok := item.Value().([]stackitem.Item) + if !ok { + return nil, errors.New("not an array") + } + res := make([][][]string, len(arr)) + for i := range res { + res[i], err = func (item stackitem.Item) ([][]string, error) { + arr, ok := item.Value().([]stackitem.Item) + if !ok { + return nil, errors.New("not an array") + } + res := make([][]string, len(arr)) + for i := range res { + res[i], err = func (item stackitem.Item) ([]string, error) { + arr, ok := item.Value().([]stackitem.Item) + if !ok { + return nil, errors.New("not an array") + } + res := make([]string, len(arr)) + for i := range res { + res[i], err = func (item stackitem.Item) (string, error) { + b, err := item.TryBytes() + if err != nil { + return "", err + } + if !utf8.Valid(b) { + return "", errors.New("not a UTF-8 string") + } + return string(b), nil + } (arr[i]) + if err != nil { + return nil, err + } + } + return res, nil + } (arr[i]) + if err != nil { + return nil, err + } + } + return res, nil + } (arr[i]) + if err != nil { + return nil, err + } + } + return res, nil + } (item) + } (unwrap.Item(c.invoker.Call(Hash, "aAAStrings", s))) +} + // Bool invokes `bool` method of contract. func (c *ContractReader) Bool(b bool) (bool, error) { return unwrap.Bool(c.invoker.Call(Hash, "bool", b)) @@ -48,6 +110,97 @@ func (c *ContractReader) Bytess(b [][]byte) ([][]byte, error) { return unwrap.ArrayOfBytes(c.invoker.Call(Hash, "bytess", b)) } +// CrazyMaps invokes `crazyMaps` method of contract. +func (c *ContractReader) CrazyMaps(m map[*big.Int][]map[string][]util.Uint160) (map[*big.Int][]map[string][]util.Uint160, error) { + return func (item stackitem.Item, err error) (map[*big.Int][]map[string][]util.Uint160, error) { + if err != nil { + return nil, err + } + return func (item stackitem.Item) (map[*big.Int][]map[string][]util.Uint160, error) { + m, ok := item.Value().([]stackitem.MapElement) + if !ok { + return nil, fmt.Errorf("%s is not a map", item.Type().String()) + } + res := make(map[*big.Int][]map[string][]util.Uint160) + for i := range m { + k, err := m[i].Key.TryInteger() + if err != nil { + return nil, err + } + v, err := func (item stackitem.Item) ([]map[string][]util.Uint160, error) { + arr, ok := item.Value().([]stackitem.Item) + if !ok { + return nil, errors.New("not an array") + } + res := make([]map[string][]util.Uint160, len(arr)) + for i := range res { + res[i], err = func (item stackitem.Item) (map[string][]util.Uint160, error) { + m, ok := item.Value().([]stackitem.MapElement) + if !ok { + return nil, fmt.Errorf("%s is not a map", item.Type().String()) + } + res := make(map[string][]util.Uint160) + for i := range m { + k, err := func (item stackitem.Item) (string, error) { + b, err := item.TryBytes() + if err != nil { + return "", err + } + if !utf8.Valid(b) { + return "", errors.New("not a UTF-8 string") + } + return string(b), nil + } (m[i].Key) + if err != nil { + return nil, err + } + v, err := func (item stackitem.Item) ([]util.Uint160, error) { + arr, ok := item.Value().([]stackitem.Item) + if !ok { + return nil, errors.New("not an array") + } + res := make([]util.Uint160, len(arr)) + for i := range res { + res[i], err = func (item stackitem.Item) (util.Uint160, error) { + b, err := item.TryBytes() + if err != nil { + return util.Uint160{}, err + } + u, err := util.Uint160DecodeBytesBE(b) + if err != nil { + return util.Uint160{}, err + } + return u, nil + } (arr[i]) + if err != nil { + return nil, err + } + } + return res, nil + } (m[i].Value) + if err != nil { + return nil, err + } + res[k] = v + } + return res, nil + } (arr[i]) + if err != nil { + return nil, err + } + } + return res, nil + } (m[i].Value) + if err != nil { + return nil, err + } + res[k] = v + } + return res, nil + } (item) + } (unwrap.Item(c.invoker.Call(Hash, "crazyMaps", m))) +} + // Hash160 invokes `hash160` method of contract. func (c *ContractReader) Hash160(h util.Uint160) (util.Uint160, error) { return unwrap.Uint160(c.invoker.Call(Hash, "hash160", h)) @@ -78,6 +231,52 @@ func (c *ContractReader) Ints(i []*big.Int) ([]*big.Int, error) { return unwrap.ArrayOfBigInts(c.invoker.Call(Hash, "ints", i)) } +// Maps invokes `maps` method of contract. +func (c *ContractReader) Maps(m map[string]string) (map[string]string, error) { + return func (item stackitem.Item, err error) (map[string]string, error) { + if err != nil { + return nil, err + } + return func (item stackitem.Item) (map[string]string, error) { + m, ok := item.Value().([]stackitem.MapElement) + if !ok { + return nil, fmt.Errorf("%s is not a map", item.Type().String()) + } + res := make(map[string]string) + for i := range m { + k, err := func (item stackitem.Item) (string, error) { + b, err := item.TryBytes() + if err != nil { + return "", err + } + if !utf8.Valid(b) { + return "", errors.New("not a UTF-8 string") + } + return string(b), nil + } (m[i].Key) + if err != nil { + return nil, err + } + v, err := func (item stackitem.Item) (string, error) { + b, err := item.TryBytes() + if err != nil { + return "", err + } + if !utf8.Valid(b) { + return "", errors.New("not a UTF-8 string") + } + return string(b), nil + } (m[i].Value) + if err != nil { + return nil, err + } + res[k] = v + } + return res, nil + } (item) + } (unwrap.Item(c.invoker.Call(Hash, "maps", m))) +} + // PublicKey invokes `publicKey` method of contract. func (c *ContractReader) PublicKey(k *keys.PublicKey) (*keys.PublicKey, error) { return unwrap.PublicKey(c.invoker.Call(Hash, "publicKey", k)) diff --git a/cli/smartcontract/testdata/types/types.go b/cli/smartcontract/testdata/types/types.go index dcd52ae90..2d97dd081 100644 --- a/cli/smartcontract/testdata/types/types.go +++ b/cli/smartcontract/testdata/types/types.go @@ -67,3 +67,15 @@ func PublicKeys(k []interop.PublicKey) []interop.PublicKey { func Signatures(s []interop.Signature) []interop.Signature { return nil } + +func AAAStrings(s [][][]string) [][][]string { + return s +} + +func Maps(m map[string]string) map[string]string { + return m +} + +func CrazyMaps(m map[int][]map[string][]interop.Hash160) map[int][]map[string][]interop.Hash160 { + return m +} diff --git a/pkg/smartcontract/rpcbinding/binding.go b/pkg/smartcontract/rpcbinding/binding.go index 6b10a5163..e89988041 100644 --- a/pkg/smartcontract/rpcbinding/binding.go +++ b/pkg/smartcontract/rpcbinding/binding.go @@ -19,18 +19,20 @@ func (c *ContractReader) {{.Name}}({{range $index, $arg := .Arguments -}} {{- if ne $index 0}}, {{end}} {{- .Name}} {{.Type}} {{- end}}) {{if .ReturnType }}({{ .ReturnType }}, error) { - return {{if .CallFlag -}} - unwrap.{{.CallFlag}}(c.invoker.Call(Hash, "{{ .NameABI }}"{{/* CallFlag field is used for function name */}} - {{- else -}} - itemTo{{ cutPointer .ReturnType }}(unwrap.Item(c.invoker.Call(Hash, "{{ .NameABI }}"{{/* CallFlag field is used for function name */}} - {{- end -}} - {{- range $arg := .Arguments -}}, {{.Name}}{{end}})){{if not .CallFlag}}){{end}} + return {{if and (not .ItemTo) (eq .Unwrapper "Item")}}func (item stackitem.Item, err error) ({{ .ReturnType }}, error) { + if err != nil { + return nil, err + } + return {{etTypeConverter .ExtendedReturn "item"}} + } ( {{- end -}} {{if .ItemTo -}} itemTo{{ .ItemTo }}( {{- end -}} + unwrap.{{.Unwrapper}}(c.invoker.Call(Hash, "{{ .NameABI }}" + {{- range $arg := .Arguments -}}, {{.Name}}{{end -}} )) {{- if or .ItemTo (eq .Unwrapper "Item") -}} ) {{- end}} {{- else -}} (*result.Invoke, error) { c.invoker.Call(Hash, "{{ .NameABI }}" {{- range $arg := .Arguments -}}, {{.Name}}{{end}}) {{- end}} } -{{- if eq .CallFlag "SessionIterator"}} +{{- if eq .Unwrapper "SessionIterator"}} // {{.Name}}Expanded is similar to {{.Name}} (uses the same contract // method), but can be useful if the server used doesn't support sessions and @@ -245,7 +247,7 @@ type ( ContractTmpl struct { binding.ContractTmpl - SafeMethods []binding.MethodTmpl + SafeMethods []SafeMethodTmpl NamedTypes map[string]binding.ExtendedType IsNep11D bool @@ -256,6 +258,13 @@ type ( HasWriter bool HasIterator bool } + + SafeMethodTmpl struct { + binding.MethodTmpl + Unwrapper string + ItemTo string + ExtendedReturn binding.ExtendedType + } ) // NewConfig initializes and returns a new config instance. @@ -518,7 +527,15 @@ func scTemplateToRPC(cfg binding.Config, ctr ContractTmpl, imports map[string]st for i := 0; i < len(ctr.Methods); i++ { abim := cfg.Manifest.ABI.GetMethod(ctr.Methods[i].NameABI, len(ctr.Methods[i].Arguments)) if abim.Safe { - ctr.SafeMethods = append(ctr.SafeMethods, ctr.Methods[i]) + ctr.SafeMethods = append(ctr.SafeMethods, SafeMethodTmpl{MethodTmpl: ctr.Methods[i]}) + et, ok := cfg.Types[abim.Name] + if ok { + ctr.SafeMethods[len(ctr.SafeMethods)-1].ExtendedReturn = et + if abim.ReturnType == smartcontract.ArrayType && len(et.Name) > 0 { + imports["errors"] = struct{}{} + ctr.SafeMethods[len(ctr.SafeMethods)-1].ItemTo = cutPointer(ctr.Methods[i].ReturnType) + } + } ctr.Methods = append(ctr.Methods[:i], ctr.Methods[i+1:]...) i-- } else { @@ -529,27 +546,12 @@ func scTemplateToRPC(cfg binding.Config, ctr ContractTmpl, imports map[string]st } } for _, et := range cfg.NamedTypes { - for _, fet := range et.Fields { - _, pkg := extendedTypeToGo(fet.ExtendedType, ctr.NamedTypes) - if pkg != "" { - imports[pkg] = struct{}{} - } - // Additional packages used during decoding. - switch fet.Base { - case smartcontract.StringType: - imports["unicode/utf8"] = struct{}{} - case smartcontract.PublicKeyType: - imports["crypto/elliptic"] = struct{}{} - case smartcontract.MapType: - imports["fmt"] = struct{}{} - } - } + addETImports(et, ctr.NamedTypes, imports) } if len(cfg.NamedTypes) > 0 { imports["errors"] = struct{}{} } - // We're misusing CallFlag field for function name here. for i := range ctr.SafeMethods { switch ctr.SafeMethods[i].ReturnType { case "interface{}": @@ -559,49 +561,50 @@ func scTemplateToRPC(cfg binding.Config, ctr ContractTmpl, imports map[string]st imports["github.com/nspcc-dev/neo-go/pkg/vm/stackitem"] = struct{}{} imports["github.com/nspcc-dev/neo-go/pkg/neorpc/result"] = struct{}{} ctr.SafeMethods[i].ReturnType = "uuid.UUID, result.Iterator" - ctr.SafeMethods[i].CallFlag = "SessionIterator" + ctr.SafeMethods[i].Unwrapper = "SessionIterator" ctr.HasIterator = true } else { imports["github.com/nspcc-dev/neo-go/pkg/vm/stackitem"] = struct{}{} ctr.SafeMethods[i].ReturnType = "stackitem.Item" - ctr.SafeMethods[i].CallFlag = "Item" + ctr.SafeMethods[i].Unwrapper = "Item" } case "bool": - ctr.SafeMethods[i].CallFlag = "Bool" + ctr.SafeMethods[i].Unwrapper = "Bool" case "*big.Int": - ctr.SafeMethods[i].CallFlag = "BigInt" + ctr.SafeMethods[i].Unwrapper = "BigInt" case "string": - ctr.SafeMethods[i].CallFlag = "UTF8String" + ctr.SafeMethods[i].Unwrapper = "UTF8String" case "util.Uint160": - ctr.SafeMethods[i].CallFlag = "Uint160" + ctr.SafeMethods[i].Unwrapper = "Uint160" case "util.Uint256": - ctr.SafeMethods[i].CallFlag = "Uint256" + ctr.SafeMethods[i].Unwrapper = "Uint256" case "*keys.PublicKey": - ctr.SafeMethods[i].CallFlag = "PublicKey" + ctr.SafeMethods[i].Unwrapper = "PublicKey" case "[]byte": - ctr.SafeMethods[i].CallFlag = "Bytes" + ctr.SafeMethods[i].Unwrapper = "Bytes" case "[]interface{}": imports["github.com/nspcc-dev/neo-go/pkg/vm/stackitem"] = struct{}{} ctr.SafeMethods[i].ReturnType = "[]stackitem.Item" - ctr.SafeMethods[i].CallFlag = "Array" + ctr.SafeMethods[i].Unwrapper = "Array" case "*stackitem.Map": - ctr.SafeMethods[i].CallFlag = "Map" + ctr.SafeMethods[i].Unwrapper = "Map" case "[]bool": - ctr.SafeMethods[i].CallFlag = "ArrayOfBools" + ctr.SafeMethods[i].Unwrapper = "ArrayOfBools" case "[]*big.Int": - ctr.SafeMethods[i].CallFlag = "ArrayOfBigInts" + ctr.SafeMethods[i].Unwrapper = "ArrayOfBigInts" case "[][]byte": - ctr.SafeMethods[i].CallFlag = "ArrayOfBytes" + ctr.SafeMethods[i].Unwrapper = "ArrayOfBytes" case "[]string": - ctr.SafeMethods[i].CallFlag = "ArrayOfUTF8Strings" + ctr.SafeMethods[i].Unwrapper = "ArrayOfUTF8Strings" case "[]util.Uint160": - ctr.SafeMethods[i].CallFlag = "ArrayOfUint160" + ctr.SafeMethods[i].Unwrapper = "ArrayOfUint160" case "[]util.Uint256": - ctr.SafeMethods[i].CallFlag = "ArrayOfUint256" + ctr.SafeMethods[i].Unwrapper = "ArrayOfUint256" case "keys.PublicKeys": - ctr.SafeMethods[i].CallFlag = "ArrayOfPublicKeys" + ctr.SafeMethods[i].Unwrapper = "ArrayOfPublicKeys" default: - ctr.SafeMethods[i].CallFlag = "" + addETImports(ctr.SafeMethods[i].ExtendedReturn, ctr.NamedTypes, imports) + ctr.SafeMethods[i].Unwrapper = "Item" } } @@ -629,6 +632,32 @@ func scTemplateToRPC(cfg binding.Config, ctr ContractTmpl, imports map[string]st return ctr } +func addETImports(et binding.ExtendedType, named map[string]binding.ExtendedType, imports map[string]struct{}) { + _, pkg := extendedTypeToGo(et, named) + if pkg != "" { + imports[pkg] = struct{}{} + } + // Additional packages used during decoding. + switch et.Base { + case smartcontract.StringType: + imports["unicode/utf8"] = struct{}{} + imports["errors"] = struct{}{} + case smartcontract.PublicKeyType: + imports["crypto/elliptic"] = struct{}{} + case smartcontract.MapType: + imports["fmt"] = struct{}{} + } + if et.Value != nil { + addETImports(*et.Value, named, imports) + } + if et.Base == smartcontract.MapType { + addETImports(binding.ExtendedType{Base: et.Key}, named, imports) + } + for i := range et.Fields { + addETImports(et.Fields[i].ExtendedType, named, imports) + } +} + func cutPointer(s string) string { if s[0] == '*' { return s[1:]