rpcbinding: handle more complex non-structured types

This commit is contained in:
Roman Khimov 2022-12-02 10:20:55 +03:00
parent ce67e6795e
commit c058ab5604
4 changed files with 285 additions and 45 deletions

View file

@ -1,3 +1,3 @@
name: "Types"
sourceurl: https://github.com/nspcc-dev/neo-go/
safemethods: ["bool", "int", "bytes", "string", "hash160", "hash256", "publicKey", "signature", "bools", "ints", "bytess", "strings", "hash160s", "hash256s", "publicKeys", "signatures"]
safemethods: ["bool", "int", "bytes", "string", "hash160", "hash256", "publicKey", "signature", "bools", "ints", "bytess", "strings", "hash160s", "hash256s", "publicKeys", "signatures", "aAAStrings", "maps", "crazyMaps"]

View file

@ -2,11 +2,15 @@
package types
import (
"errors"
"fmt"
"github.com/nspcc-dev/neo-go/pkg/crypto/keys"
"github.com/nspcc-dev/neo-go/pkg/neorpc/result"
"github.com/nspcc-dev/neo-go/pkg/rpcclient/unwrap"
"github.com/nspcc-dev/neo-go/pkg/util"
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
"math/big"
"unicode/utf8"
)
// Hash contains contract hash.
@ -28,6 +32,64 @@ func NewReader(invoker Invoker) *ContractReader {
}
// AAAStrings invokes `aAAStrings` method of contract.
func (c *ContractReader) AAAStrings(s [][][]string) ([][][]string, error) {
return func (item stackitem.Item, err error) ([][][]string, error) {
if err != nil {
return nil, err
}
return func (item stackitem.Item) ([][][]string, error) {
arr, ok := item.Value().([]stackitem.Item)
if !ok {
return nil, errors.New("not an array")
}
res := make([][][]string, len(arr))
for i := range res {
res[i], err = func (item stackitem.Item) ([][]string, error) {
arr, ok := item.Value().([]stackitem.Item)
if !ok {
return nil, errors.New("not an array")
}
res := make([][]string, len(arr))
for i := range res {
res[i], err = func (item stackitem.Item) ([]string, error) {
arr, ok := item.Value().([]stackitem.Item)
if !ok {
return nil, errors.New("not an array")
}
res := make([]string, len(arr))
for i := range res {
res[i], err = func (item stackitem.Item) (string, error) {
b, err := item.TryBytes()
if err != nil {
return "", err
}
if !utf8.Valid(b) {
return "", errors.New("not a UTF-8 string")
}
return string(b), nil
} (arr[i])
if err != nil {
return nil, err
}
}
return res, nil
} (arr[i])
if err != nil {
return nil, err
}
}
return res, nil
} (arr[i])
if err != nil {
return nil, err
}
}
return res, nil
} (item)
} (unwrap.Item(c.invoker.Call(Hash, "aAAStrings", s)))
}
// Bool invokes `bool` method of contract.
func (c *ContractReader) Bool(b bool) (bool, error) {
return unwrap.Bool(c.invoker.Call(Hash, "bool", b))
@ -48,6 +110,97 @@ func (c *ContractReader) Bytess(b [][]byte) ([][]byte, error) {
return unwrap.ArrayOfBytes(c.invoker.Call(Hash, "bytess", b))
}
// CrazyMaps invokes `crazyMaps` method of contract.
func (c *ContractReader) CrazyMaps(m map[*big.Int][]map[string][]util.Uint160) (map[*big.Int][]map[string][]util.Uint160, error) {
return func (item stackitem.Item, err error) (map[*big.Int][]map[string][]util.Uint160, error) {
if err != nil {
return nil, err
}
return func (item stackitem.Item) (map[*big.Int][]map[string][]util.Uint160, error) {
m, ok := item.Value().([]stackitem.MapElement)
if !ok {
return nil, fmt.Errorf("%s is not a map", item.Type().String())
}
res := make(map[*big.Int][]map[string][]util.Uint160)
for i := range m {
k, err := m[i].Key.TryInteger()
if err != nil {
return nil, err
}
v, err := func (item stackitem.Item) ([]map[string][]util.Uint160, error) {
arr, ok := item.Value().([]stackitem.Item)
if !ok {
return nil, errors.New("not an array")
}
res := make([]map[string][]util.Uint160, len(arr))
for i := range res {
res[i], err = func (item stackitem.Item) (map[string][]util.Uint160, error) {
m, ok := item.Value().([]stackitem.MapElement)
if !ok {
return nil, fmt.Errorf("%s is not a map", item.Type().String())
}
res := make(map[string][]util.Uint160)
for i := range m {
k, err := func (item stackitem.Item) (string, error) {
b, err := item.TryBytes()
if err != nil {
return "", err
}
if !utf8.Valid(b) {
return "", errors.New("not a UTF-8 string")
}
return string(b), nil
} (m[i].Key)
if err != nil {
return nil, err
}
v, err := func (item stackitem.Item) ([]util.Uint160, error) {
arr, ok := item.Value().([]stackitem.Item)
if !ok {
return nil, errors.New("not an array")
}
res := make([]util.Uint160, len(arr))
for i := range res {
res[i], err = func (item stackitem.Item) (util.Uint160, error) {
b, err := item.TryBytes()
if err != nil {
return util.Uint160{}, err
}
u, err := util.Uint160DecodeBytesBE(b)
if err != nil {
return util.Uint160{}, err
}
return u, nil
} (arr[i])
if err != nil {
return nil, err
}
}
return res, nil
} (m[i].Value)
if err != nil {
return nil, err
}
res[k] = v
}
return res, nil
} (arr[i])
if err != nil {
return nil, err
}
}
return res, nil
} (m[i].Value)
if err != nil {
return nil, err
}
res[k] = v
}
return res, nil
} (item)
} (unwrap.Item(c.invoker.Call(Hash, "crazyMaps", m)))
}
// Hash160 invokes `hash160` method of contract.
func (c *ContractReader) Hash160(h util.Uint160) (util.Uint160, error) {
return unwrap.Uint160(c.invoker.Call(Hash, "hash160", h))
@ -78,6 +231,52 @@ func (c *ContractReader) Ints(i []*big.Int) ([]*big.Int, error) {
return unwrap.ArrayOfBigInts(c.invoker.Call(Hash, "ints", i))
}
// Maps invokes `maps` method of contract.
func (c *ContractReader) Maps(m map[string]string) (map[string]string, error) {
return func (item stackitem.Item, err error) (map[string]string, error) {
if err != nil {
return nil, err
}
return func (item stackitem.Item) (map[string]string, error) {
m, ok := item.Value().([]stackitem.MapElement)
if !ok {
return nil, fmt.Errorf("%s is not a map", item.Type().String())
}
res := make(map[string]string)
for i := range m {
k, err := func (item stackitem.Item) (string, error) {
b, err := item.TryBytes()
if err != nil {
return "", err
}
if !utf8.Valid(b) {
return "", errors.New("not a UTF-8 string")
}
return string(b), nil
} (m[i].Key)
if err != nil {
return nil, err
}
v, err := func (item stackitem.Item) (string, error) {
b, err := item.TryBytes()
if err != nil {
return "", err
}
if !utf8.Valid(b) {
return "", errors.New("not a UTF-8 string")
}
return string(b), nil
} (m[i].Value)
if err != nil {
return nil, err
}
res[k] = v
}
return res, nil
} (item)
} (unwrap.Item(c.invoker.Call(Hash, "maps", m)))
}
// PublicKey invokes `publicKey` method of contract.
func (c *ContractReader) PublicKey(k *keys.PublicKey) (*keys.PublicKey, error) {
return unwrap.PublicKey(c.invoker.Call(Hash, "publicKey", k))

View file

@ -67,3 +67,15 @@ func PublicKeys(k []interop.PublicKey) []interop.PublicKey {
func Signatures(s []interop.Signature) []interop.Signature {
return nil
}
func AAAStrings(s [][][]string) [][][]string {
return s
}
func Maps(m map[string]string) map[string]string {
return m
}
func CrazyMaps(m map[int][]map[string][]interop.Hash160) map[int][]map[string][]interop.Hash160 {
return m
}

View file

@ -19,18 +19,20 @@ func (c *ContractReader) {{.Name}}({{range $index, $arg := .Arguments -}}
{{- if ne $index 0}}, {{end}}
{{- .Name}} {{.Type}}
{{- end}}) {{if .ReturnType }}({{ .ReturnType }}, error) {
return {{if .CallFlag -}}
unwrap.{{.CallFlag}}(c.invoker.Call(Hash, "{{ .NameABI }}"{{/* CallFlag field is used for function name */}}
{{- else -}}
itemTo{{ cutPointer .ReturnType }}(unwrap.Item(c.invoker.Call(Hash, "{{ .NameABI }}"{{/* CallFlag field is used for function name */}}
{{- end -}}
{{- range $arg := .Arguments -}}, {{.Name}}{{end}})){{if not .CallFlag}}){{end}}
return {{if and (not .ItemTo) (eq .Unwrapper "Item")}}func (item stackitem.Item, err error) ({{ .ReturnType }}, error) {
if err != nil {
return nil, err
}
return {{etTypeConverter .ExtendedReturn "item"}}
} ( {{- end -}} {{if .ItemTo -}} itemTo{{ .ItemTo }}( {{- end -}}
unwrap.{{.Unwrapper}}(c.invoker.Call(Hash, "{{ .NameABI }}"
{{- range $arg := .Arguments -}}, {{.Name}}{{end -}} )) {{- if or .ItemTo (eq .Unwrapper "Item") -}} ) {{- end}}
{{- else -}} (*result.Invoke, error) {
c.invoker.Call(Hash, "{{ .NameABI }}"
{{- range $arg := .Arguments -}}, {{.Name}}{{end}})
{{- end}}
}
{{- if eq .CallFlag "SessionIterator"}}
{{- if eq .Unwrapper "SessionIterator"}}
// {{.Name}}Expanded is similar to {{.Name}} (uses the same contract
// method), but can be useful if the server used doesn't support sessions and
@ -245,7 +247,7 @@ type (
ContractTmpl struct {
binding.ContractTmpl
SafeMethods []binding.MethodTmpl
SafeMethods []SafeMethodTmpl
NamedTypes map[string]binding.ExtendedType
IsNep11D bool
@ -256,6 +258,13 @@ type (
HasWriter bool
HasIterator bool
}
SafeMethodTmpl struct {
binding.MethodTmpl
Unwrapper string
ItemTo string
ExtendedReturn binding.ExtendedType
}
)
// NewConfig initializes and returns a new config instance.
@ -518,7 +527,15 @@ func scTemplateToRPC(cfg binding.Config, ctr ContractTmpl, imports map[string]st
for i := 0; i < len(ctr.Methods); i++ {
abim := cfg.Manifest.ABI.GetMethod(ctr.Methods[i].NameABI, len(ctr.Methods[i].Arguments))
if abim.Safe {
ctr.SafeMethods = append(ctr.SafeMethods, ctr.Methods[i])
ctr.SafeMethods = append(ctr.SafeMethods, SafeMethodTmpl{MethodTmpl: ctr.Methods[i]})
et, ok := cfg.Types[abim.Name]
if ok {
ctr.SafeMethods[len(ctr.SafeMethods)-1].ExtendedReturn = et
if abim.ReturnType == smartcontract.ArrayType && len(et.Name) > 0 {
imports["errors"] = struct{}{}
ctr.SafeMethods[len(ctr.SafeMethods)-1].ItemTo = cutPointer(ctr.Methods[i].ReturnType)
}
}
ctr.Methods = append(ctr.Methods[:i], ctr.Methods[i+1:]...)
i--
} else {
@ -529,27 +546,12 @@ func scTemplateToRPC(cfg binding.Config, ctr ContractTmpl, imports map[string]st
}
}
for _, et := range cfg.NamedTypes {
for _, fet := range et.Fields {
_, pkg := extendedTypeToGo(fet.ExtendedType, ctr.NamedTypes)
if pkg != "" {
imports[pkg] = struct{}{}
}
// Additional packages used during decoding.
switch fet.Base {
case smartcontract.StringType:
imports["unicode/utf8"] = struct{}{}
case smartcontract.PublicKeyType:
imports["crypto/elliptic"] = struct{}{}
case smartcontract.MapType:
imports["fmt"] = struct{}{}
}
}
addETImports(et, ctr.NamedTypes, imports)
}
if len(cfg.NamedTypes) > 0 {
imports["errors"] = struct{}{}
}
// We're misusing CallFlag field for function name here.
for i := range ctr.SafeMethods {
switch ctr.SafeMethods[i].ReturnType {
case "interface{}":
@ -559,49 +561,50 @@ func scTemplateToRPC(cfg binding.Config, ctr ContractTmpl, imports map[string]st
imports["github.com/nspcc-dev/neo-go/pkg/vm/stackitem"] = struct{}{}
imports["github.com/nspcc-dev/neo-go/pkg/neorpc/result"] = struct{}{}
ctr.SafeMethods[i].ReturnType = "uuid.UUID, result.Iterator"
ctr.SafeMethods[i].CallFlag = "SessionIterator"
ctr.SafeMethods[i].Unwrapper = "SessionIterator"
ctr.HasIterator = true
} else {
imports["github.com/nspcc-dev/neo-go/pkg/vm/stackitem"] = struct{}{}
ctr.SafeMethods[i].ReturnType = "stackitem.Item"
ctr.SafeMethods[i].CallFlag = "Item"
ctr.SafeMethods[i].Unwrapper = "Item"
}
case "bool":
ctr.SafeMethods[i].CallFlag = "Bool"
ctr.SafeMethods[i].Unwrapper = "Bool"
case "*big.Int":
ctr.SafeMethods[i].CallFlag = "BigInt"
ctr.SafeMethods[i].Unwrapper = "BigInt"
case "string":
ctr.SafeMethods[i].CallFlag = "UTF8String"
ctr.SafeMethods[i].Unwrapper = "UTF8String"
case "util.Uint160":
ctr.SafeMethods[i].CallFlag = "Uint160"
ctr.SafeMethods[i].Unwrapper = "Uint160"
case "util.Uint256":
ctr.SafeMethods[i].CallFlag = "Uint256"
ctr.SafeMethods[i].Unwrapper = "Uint256"
case "*keys.PublicKey":
ctr.SafeMethods[i].CallFlag = "PublicKey"
ctr.SafeMethods[i].Unwrapper = "PublicKey"
case "[]byte":
ctr.SafeMethods[i].CallFlag = "Bytes"
ctr.SafeMethods[i].Unwrapper = "Bytes"
case "[]interface{}":
imports["github.com/nspcc-dev/neo-go/pkg/vm/stackitem"] = struct{}{}
ctr.SafeMethods[i].ReturnType = "[]stackitem.Item"
ctr.SafeMethods[i].CallFlag = "Array"
ctr.SafeMethods[i].Unwrapper = "Array"
case "*stackitem.Map":
ctr.SafeMethods[i].CallFlag = "Map"
ctr.SafeMethods[i].Unwrapper = "Map"
case "[]bool":
ctr.SafeMethods[i].CallFlag = "ArrayOfBools"
ctr.SafeMethods[i].Unwrapper = "ArrayOfBools"
case "[]*big.Int":
ctr.SafeMethods[i].CallFlag = "ArrayOfBigInts"
ctr.SafeMethods[i].Unwrapper = "ArrayOfBigInts"
case "[][]byte":
ctr.SafeMethods[i].CallFlag = "ArrayOfBytes"
ctr.SafeMethods[i].Unwrapper = "ArrayOfBytes"
case "[]string":
ctr.SafeMethods[i].CallFlag = "ArrayOfUTF8Strings"
ctr.SafeMethods[i].Unwrapper = "ArrayOfUTF8Strings"
case "[]util.Uint160":
ctr.SafeMethods[i].CallFlag = "ArrayOfUint160"
ctr.SafeMethods[i].Unwrapper = "ArrayOfUint160"
case "[]util.Uint256":
ctr.SafeMethods[i].CallFlag = "ArrayOfUint256"
ctr.SafeMethods[i].Unwrapper = "ArrayOfUint256"
case "keys.PublicKeys":
ctr.SafeMethods[i].CallFlag = "ArrayOfPublicKeys"
ctr.SafeMethods[i].Unwrapper = "ArrayOfPublicKeys"
default:
ctr.SafeMethods[i].CallFlag = ""
addETImports(ctr.SafeMethods[i].ExtendedReturn, ctr.NamedTypes, imports)
ctr.SafeMethods[i].Unwrapper = "Item"
}
}
@ -629,6 +632,32 @@ func scTemplateToRPC(cfg binding.Config, ctr ContractTmpl, imports map[string]st
return ctr
}
func addETImports(et binding.ExtendedType, named map[string]binding.ExtendedType, imports map[string]struct{}) {
_, pkg := extendedTypeToGo(et, named)
if pkg != "" {
imports[pkg] = struct{}{}
}
// Additional packages used during decoding.
switch et.Base {
case smartcontract.StringType:
imports["unicode/utf8"] = struct{}{}
imports["errors"] = struct{}{}
case smartcontract.PublicKeyType:
imports["crypto/elliptic"] = struct{}{}
case smartcontract.MapType:
imports["fmt"] = struct{}{}
}
if et.Value != nil {
addETImports(*et.Value, named, imports)
}
if et.Base == smartcontract.MapType {
addETImports(binding.ExtendedType{Base: et.Key}, named, imports)
}
for i := range et.Fields {
addETImports(et.Fields[i].ExtendedType, named, imports)
}
}
func cutPointer(s string) string {
if s[0] == '*' {
return s[1:]