rpcbinding: generate bindings using new extended type data

This commit is contained in:
Roman Khimov 2022-11-25 13:04:24 +03:00
parent 2bcf3a4ad5
commit 0214628127
7 changed files with 1687 additions and 38 deletions

View file

@ -3,6 +3,7 @@ package rpcbinding
import (
"fmt"
"sort"
"strings"
"text/template"
"github.com/nspcc-dev/neo-go/pkg/smartcontract"
@ -18,8 +19,12 @@ func (c *ContractReader) {{.Name}}({{range $index, $arg := .Arguments -}}
{{- if ne $index 0}}, {{end}}
{{- .Name}} {{.Type}}
{{- end}}) {{if .ReturnType }}({{ .ReturnType }}, error) {
return unwrap.{{.CallFlag}}(c.invoker.Call(Hash, "{{ .NameABI }}"{{/* CallFlag field is used for function name */}}
{{- range $arg := .Arguments -}}, {{.Name}}{{end}}))
return {{if .CallFlag -}}
unwrap.{{.CallFlag}}(c.invoker.Call(Hash, "{{ .NameABI }}"{{/* CallFlag field is used for function name */}}
{{- else -}}
itemTo{{ cutPointer .ReturnType }}(unwrap.Item(c.invoker.Call(Hash, "{{ .NameABI }}"{{/* CallFlag field is used for function name */}}
{{- end -}}
{{- range $arg := .Arguments -}}, {{.Name}}{{end}})){{if not .CallFlag}}){{end}}
{{- else -}} (*result.Invoke, error) {
c.invoker.Call(Hash, "{{ .NameABI }}"
{{- range $arg := .Arguments -}}, {{.Name}}{{end}})
@ -101,6 +106,14 @@ import (
// Hash contains contract hash.
var Hash = {{ .Hash }}
{{range $name, $typ := .NamedTypes}}
// {{toTypeName $name}} is a contract-specific {{$name}} type used by its methods.
type {{toTypeName $name}} struct {
{{- range $m := $typ.Fields}}
{{.Field}} {{etTypeToStr .ExtendedType}}
{{- end}}
}
{{end -}}
{{if .HasReader}}// Invoker is used by ContractReader to call various safe methods.
type Invoker interface {
{{if or .IsNep11D .IsNep11ND}} nep11.Invoker
@ -199,15 +212,41 @@ func New(actor Actor) *Contract {
{{end}}
{{- range $m := .Methods}}
{{template "METHOD" $m }}
{{end}}`
{{end}}
{{- range $name, $typ := .NamedTypes}}
// itemTo{{toTypeName $name}} converts stack item into *{{toTypeName $name}}.
func itemTo{{toTypeName $name}}(item stackitem.Item, err error) (*{{toTypeName $name}}, error) {
if err != nil {
return nil, err
}
arr, ok := item.Value().([]stackitem.Item)
if !ok {
return nil, errors.New("not an array")
}
if len(arr) != {{len $typ.Fields}} {
return nil, errors.New("wrong number of structure elements")
}
var srcTemplate = template.Must(template.New("generate").Parse(srcTmpl))
var res = new({{toTypeName $name}})
{{if len .Fields}} var index = -1
{{- range $m := $typ.Fields}}
index++
res.{{.Field}}, err = {{etTypeConverter .ExtendedType "arr[index]"}}
if err != nil {
return nil, err
}
{{end}}
{{end}}
return res, err
}
{{end}}`
type (
ContractTmpl struct {
binding.ContractTmpl
SafeMethods []binding.MethodTmpl
NamedTypes map[string]binding.ExtendedType
IsNep11D bool
IsNep11ND bool
@ -268,6 +307,17 @@ func Generate(cfg binding.Config) error {
ctr.ContractTmpl = binding.TemplateFromManifest(cfg, scTypeToGo)
ctr = scTemplateToRPC(cfg, ctr, imports)
ctr.NamedTypes = cfg.NamedTypes
var srcTemplate = template.Must(template.New("generate").Funcs(template.FuncMap{
"etTypeConverter": etTypeConverter,
"etTypeToStr": func(et binding.ExtendedType) string {
r, _ := extendedTypeToGo(et, ctr.NamedTypes)
return r
},
"toTypeName": toTypeName,
"cutPointer": cutPointer,
}).Parse(srcTmpl))
return srcTemplate.Execute(cfg.Output, ctr)
}
@ -295,31 +345,8 @@ func dropStdMethods(meths []manifest.Method, std *standard.Standard) []manifest.
return meths
}
func scTypeToGo(name string, typ smartcontract.ParamType, overrides map[string]binding.Override) (string, string) {
over, ok := overrides[name]
if ok {
switch over.TypeName {
case "[]bool":
return "[]bool", ""
case "[]int", "[]uint", "[]int8", "[]uint8", "[]int16",
"[]uint16", "[]int32", "[]uint32", "[]int64", "[]uint64":
return "[]*big.Int", "math/big"
case "[][]byte":
return "[][]byte", ""
case "[]string":
return "[]string", ""
case "[]interop.Hash160":
return "[]util.Uint160", "github.com/nspcc-dev/neo-go/pkg/util"
case "[]interop.Hash256":
return "[]util.Uint256", "github.com/nspcc-dev/neo-go/pkg/util"
case "[]interop.PublicKey":
return "keys.PublicKeys", "github.com/nspcc-dev/neo-go/pkg/crypto/keys"
case "[]interop.Signature":
return "[][]byte", ""
}
}
switch typ {
func extendedTypeToGo(et binding.ExtendedType, named map[string]binding.ExtendedType) (string, string) {
switch et.Base {
case smartcontract.AnyType:
return "interface{}", ""
case smartcontract.BoolType:
@ -339,16 +366,132 @@ func scTypeToGo(name string, typ smartcontract.ParamType, overrides map[string]b
case smartcontract.SignatureType:
return "[]byte", ""
case smartcontract.ArrayType:
if len(et.Name) > 0 {
return "*" + toTypeName(et.Name), "github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
} else if et.Value != nil {
if et.Value.Base == smartcontract.PublicKeyType { // Special array wrapper.
return "keys.PublicKeys", "github.com/nspcc-dev/neo-go/pkg/crypto/keys"
}
sub, pkg := extendedTypeToGo(*et.Value, named)
return "[]" + sub, pkg
}
return "[]interface{}", ""
case smartcontract.MapType:
return "*stackitem.Map", "github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
case smartcontract.InteropInterfaceType:
return "interface{}", ""
case smartcontract.VoidType:
return "", ""
default:
panic("unreachable")
}
panic("unreachable")
}
func etTypeConverter(et binding.ExtendedType, v string) string {
switch et.Base {
case smartcontract.AnyType:
return v + ".Value(), nil"
case smartcontract.BoolType:
return v + ".TryBool()"
case smartcontract.IntegerType:
return v + ".TryInteger()"
case smartcontract.ByteArrayType, smartcontract.SignatureType:
return v + ".TryBytes()"
case smartcontract.StringType:
return `func (item stackitem.Item) (string, error) {
b, err := item.TryBytes()
if err != nil {
return "", err
}
if !utf8.Valid(b) {
return "", errors.New("not a UTF-8 string")
}
return string(b), nil
} (` + v + `)`
case smartcontract.Hash160Type:
return `func (item stackitem.Item) (util.Uint160, error) {
b, err := item.TryBytes()
if err != nil {
return util.Uint160{}, err
}
u, err := util.Uint160DecodeBytesBE(b)
if err != nil {
return util.Uint160{}, err
}
return u, nil
} (` + v + `)`
case smartcontract.Hash256Type:
return `func (item stackitem.Item) (util.Uint256, error) {
b, err := item.TryBytes()
if err != nil {
return util.Uint256{}, err
}
u, err := util.Uint256DecodeBytesBE(b)
if err != nil {
return util.Uint256{}, err
}
return u, nil
} (` + v + `)`
case smartcontract.PublicKeyType:
return `func (item stackitem.Item) (*keys.PublicKey, error) {
b, err := item.TryBytes()
if err != nil {
return nil, err
}
k, err := keys.NewPublicKeyFromBytes(b, elliptic.P256())
if err != nil {
return nil, err
}
return k, nil
} (` + v + `)`
case smartcontract.ArrayType:
if len(et.Name) > 0 {
return "itemTo" + toTypeName(et.Name) + "(" + v + ", nil)"
} else if et.Value != nil {
at, _ := extendedTypeToGo(et, nil)
return `func (item stackitem.Item) (` + at + `, error) {
arr, ok := item.Value().([]stackitem.Item)
if !ok {
return nil, errors.New("not an array")
}
res := make(` + at + `, len(arr))
for i := range res {
res[i], err = ` + etTypeConverter(*et.Value, "arr[i]") + `
if err != nil {
return nil, err
}
}
return res, nil
} (` + v + `)`
}
return etTypeConverter(binding.ExtendedType{
Base: smartcontract.ArrayType,
Value: &binding.ExtendedType{
Base: smartcontract.AnyType,
},
}, v)
case smartcontract.MapType:
return `func (item stackitem.Item) (*stackitem.Map, error) {
if t := item.Type(); t != stackitem.MapT {
return nil, fmt.Errorf("%s is not a map", t.String())
}
return item.(*stackitem.Map), nil
} (` + v + `)`
case smartcontract.InteropInterfaceType:
return "item.Value(), nil"
case smartcontract.VoidType:
return ""
}
panic("unreachable")
}
func scTypeToGo(name string, typ smartcontract.ParamType, cfg *binding.Config) (string, string) {
et, ok := cfg.Types[name]
if !ok {
et = binding.ExtendedType{Base: typ}
}
return extendedTypeToGo(et, cfg.NamedTypes)
}
func scTemplateToRPC(cfg binding.Config, ctr ContractTmpl, imports map[string]struct{}) ContractTmpl {
@ -369,6 +512,27 @@ func scTemplateToRPC(cfg binding.Config, ctr ContractTmpl, imports map[string]st
}
}
}
for _, et := range cfg.NamedTypes {
for _, fet := range et.Fields {
_, pkg := extendedTypeToGo(fet.ExtendedType, ctr.NamedTypes)
if pkg != "" {
imports[pkg] = struct{}{}
}
// Additional packages used during decoding.
switch fet.Base {
case smartcontract.StringType:
imports["unicode/utf8"] = struct{}{}
case smartcontract.PublicKeyType:
imports["crypto/elliptic"] = struct{}{}
case smartcontract.MapType:
imports["fmt"] = struct{}{}
}
}
}
if len(cfg.NamedTypes) > 0 {
imports["errors"] = struct{}{}
}
// We're misusing CallFlag field for function name here.
for i := range ctr.SafeMethods {
switch ctr.SafeMethods[i].ReturnType {
@ -420,6 +584,8 @@ func scTemplateToRPC(cfg binding.Config, ctr ContractTmpl, imports map[string]st
ctr.SafeMethods[i].CallFlag = "ArrayOfUint256"
case "keys.PublicKeys":
ctr.SafeMethods[i].CallFlag = "ArrayOfPublicKeys"
default:
ctr.SafeMethods[i].CallFlag = ""
}
}
@ -446,3 +612,19 @@ func scTemplateToRPC(cfg binding.Config, ctr ContractTmpl, imports map[string]st
sort.Strings(ctr.Imports)
return ctr
}
func cutPointer(s string) string {
if s[0] == '*' {
return s[1:]
}
return s
}
func toTypeName(s string) string {
return strings.Map(func(c rune) rune {
if c == '.' {
return -1
}
return c
}, strings.ToUpper(s[0:1])+s[1:])
}