Merge pull request #2828 from nspcc-dev/rpcwrapper-structures

Handle structures in the RPC wrapper generator
This commit is contained in:
Roman Khimov 2022-12-06 21:40:16 +07:00 committed by GitHub
commit cad0fab704
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 2220 additions and 126 deletions

View file

@ -401,6 +401,7 @@ func TestAssistedRPCBindings(t *testing.T) {
}
checkBinding(filepath.Join("testdata", "types"))
checkBinding(filepath.Join("testdata", "structs"))
}
func TestGenerate_Errors(t *testing.T) {

View file

@ -0,0 +1,3 @@
name: "Types"
sourceurl: https://github.com/nspcc-dev/neo-go/
safemethods: ["contract", "block", "transaction", "struct"]

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,39 @@
package structs
import (
"github.com/nspcc-dev/neo-go/pkg/interop"
"github.com/nspcc-dev/neo-go/pkg/interop/native/ledger"
"github.com/nspcc-dev/neo-go/pkg/interop/native/management"
)
type Internal struct {
Bool bool
Int int
Bytes []byte
String string
H160 interop.Hash160
H256 interop.Hash256
PK interop.PublicKey
PubKey interop.PublicKey
Sign interop.Signature
ArrOfBytes [][]byte
ArrOfH160 []interop.Hash160
Map map[int][]interop.PublicKey
Struct *Internal
}
func Contract(mc management.Contract) management.Contract {
return mc
}
func Block(b *ledger.Block) *ledger.Block {
return b
}
func Transaction(t *ledger.Transaction) *ledger.Transaction {
return t
}
func Struct(s *Internal) *Internal {
return s
}

View file

@ -1,3 +1,3 @@
name: "Types"
sourceurl: https://github.com/nspcc-dev/neo-go/
safemethods: ["bool", "int", "bytes", "string", "hash160", "hash256", "publicKey", "signature", "bools", "ints", "bytess", "strings", "hash160s", "hash256s", "publicKeys", "signatures"]
safemethods: ["bool", "int", "bytes", "string", "hash160", "hash256", "publicKey", "signature", "bools", "ints", "bytess", "strings", "hash160s", "hash256s", "publicKeys", "signatures", "aAAStrings", "maps", "crazyMaps"]

View file

@ -2,11 +2,15 @@
package types
import (
"errors"
"fmt"
"github.com/nspcc-dev/neo-go/pkg/crypto/keys"
"github.com/nspcc-dev/neo-go/pkg/neorpc/result"
"github.com/nspcc-dev/neo-go/pkg/rpcclient/unwrap"
"github.com/nspcc-dev/neo-go/pkg/util"
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
"math/big"
"unicode/utf8"
)
// Hash contains contract hash.
@ -28,6 +32,64 @@ func NewReader(invoker Invoker) *ContractReader {
}
// AAAStrings invokes `aAAStrings` method of contract.
func (c *ContractReader) AAAStrings(s [][][]string) ([][][]string, error) {
return func (item stackitem.Item, err error) ([][][]string, error) {
if err != nil {
return nil, err
}
return func (item stackitem.Item) ([][][]string, error) {
arr, ok := item.Value().([]stackitem.Item)
if !ok {
return nil, errors.New("not an array")
}
res := make([][][]string, len(arr))
for i := range res {
res[i], err = func (item stackitem.Item) ([][]string, error) {
arr, ok := item.Value().([]stackitem.Item)
if !ok {
return nil, errors.New("not an array")
}
res := make([][]string, len(arr))
for i := range res {
res[i], err = func (item stackitem.Item) ([]string, error) {
arr, ok := item.Value().([]stackitem.Item)
if !ok {
return nil, errors.New("not an array")
}
res := make([]string, len(arr))
for i := range res {
res[i], err = func (item stackitem.Item) (string, error) {
b, err := item.TryBytes()
if err != nil {
return "", err
}
if !utf8.Valid(b) {
return "", errors.New("not a UTF-8 string")
}
return string(b), nil
} (arr[i])
if err != nil {
return nil, fmt.Errorf("item %d: %w", i, err)
}
}
return res, nil
} (arr[i])
if err != nil {
return nil, fmt.Errorf("item %d: %w", i, err)
}
}
return res, nil
} (arr[i])
if err != nil {
return nil, fmt.Errorf("item %d: %w", i, err)
}
}
return res, nil
} (item)
} (unwrap.Item(c.invoker.Call(Hash, "aAAStrings", s)))
}
// Bool invokes `bool` method of contract.
func (c *ContractReader) Bool(b bool) (bool, error) {
return unwrap.Bool(c.invoker.Call(Hash, "bool", b))
@ -48,6 +110,97 @@ func (c *ContractReader) Bytess(b [][]byte) ([][]byte, error) {
return unwrap.ArrayOfBytes(c.invoker.Call(Hash, "bytess", b))
}
// CrazyMaps invokes `crazyMaps` method of contract.
func (c *ContractReader) CrazyMaps(m map[*big.Int][]map[string][]util.Uint160) (map[*big.Int][]map[string][]util.Uint160, error) {
return func (item stackitem.Item, err error) (map[*big.Int][]map[string][]util.Uint160, error) {
if err != nil {
return nil, err
}
return func (item stackitem.Item) (map[*big.Int][]map[string][]util.Uint160, error) {
m, ok := item.Value().([]stackitem.MapElement)
if !ok {
return nil, fmt.Errorf("%s is not a map", item.Type().String())
}
res := make(map[*big.Int][]map[string][]util.Uint160)
for i := range m {
k, err := m[i].Key.TryInteger()
if err != nil {
return nil, fmt.Errorf("key %d: %w", i, err)
}
v, err := func (item stackitem.Item) ([]map[string][]util.Uint160, error) {
arr, ok := item.Value().([]stackitem.Item)
if !ok {
return nil, errors.New("not an array")
}
res := make([]map[string][]util.Uint160, len(arr))
for i := range res {
res[i], err = func (item stackitem.Item) (map[string][]util.Uint160, error) {
m, ok := item.Value().([]stackitem.MapElement)
if !ok {
return nil, fmt.Errorf("%s is not a map", item.Type().String())
}
res := make(map[string][]util.Uint160)
for i := range m {
k, err := func (item stackitem.Item) (string, error) {
b, err := item.TryBytes()
if err != nil {
return "", err
}
if !utf8.Valid(b) {
return "", errors.New("not a UTF-8 string")
}
return string(b), nil
} (m[i].Key)
if err != nil {
return nil, fmt.Errorf("key %d: %w", i, err)
}
v, err := func (item stackitem.Item) ([]util.Uint160, error) {
arr, ok := item.Value().([]stackitem.Item)
if !ok {
return nil, errors.New("not an array")
}
res := make([]util.Uint160, len(arr))
for i := range res {
res[i], err = func (item stackitem.Item) (util.Uint160, error) {
b, err := item.TryBytes()
if err != nil {
return util.Uint160{}, err
}
u, err := util.Uint160DecodeBytesBE(b)
if err != nil {
return util.Uint160{}, err
}
return u, nil
} (arr[i])
if err != nil {
return nil, fmt.Errorf("item %d: %w", i, err)
}
}
return res, nil
} (m[i].Value)
if err != nil {
return nil, fmt.Errorf("value %d: %w", i, err)
}
res[k] = v
}
return res, nil
} (arr[i])
if err != nil {
return nil, fmt.Errorf("item %d: %w", i, err)
}
}
return res, nil
} (m[i].Value)
if err != nil {
return nil, fmt.Errorf("value %d: %w", i, err)
}
res[k] = v
}
return res, nil
} (item)
} (unwrap.Item(c.invoker.Call(Hash, "crazyMaps", m)))
}
// Hash160 invokes `hash160` method of contract.
func (c *ContractReader) Hash160(h util.Uint160) (util.Uint160, error) {
return unwrap.Uint160(c.invoker.Call(Hash, "hash160", h))
@ -78,6 +231,52 @@ func (c *ContractReader) Ints(i []*big.Int) ([]*big.Int, error) {
return unwrap.ArrayOfBigInts(c.invoker.Call(Hash, "ints", i))
}
// Maps invokes `maps` method of contract.
func (c *ContractReader) Maps(m map[string]string) (map[string]string, error) {
return func (item stackitem.Item, err error) (map[string]string, error) {
if err != nil {
return nil, err
}
return func (item stackitem.Item) (map[string]string, error) {
m, ok := item.Value().([]stackitem.MapElement)
if !ok {
return nil, fmt.Errorf("%s is not a map", item.Type().String())
}
res := make(map[string]string)
for i := range m {
k, err := func (item stackitem.Item) (string, error) {
b, err := item.TryBytes()
if err != nil {
return "", err
}
if !utf8.Valid(b) {
return "", errors.New("not a UTF-8 string")
}
return string(b), nil
} (m[i].Key)
if err != nil {
return nil, fmt.Errorf("key %d: %w", i, err)
}
v, err := func (item stackitem.Item) (string, error) {
b, err := item.TryBytes()
if err != nil {
return "", err
}
if !utf8.Valid(b) {
return "", errors.New("not a UTF-8 string")
}
return string(b), nil
} (m[i].Value)
if err != nil {
return nil, fmt.Errorf("value %d: %w", i, err)
}
res[k] = v
}
return res, nil
} (item)
} (unwrap.Item(c.invoker.Call(Hash, "maps", m)))
}
// PublicKey invokes `publicKey` method of contract.
func (c *ContractReader) PublicKey(k *keys.PublicKey) (*keys.PublicKey, error) {
return unwrap.PublicKey(c.invoker.Call(Hash, "publicKey", k))

View file

@ -67,3 +67,15 @@ func PublicKeys(k []interop.PublicKey) []interop.PublicKey {
func Signatures(s []interop.Signature) []interop.Signature {
return nil
}
func AAAStrings(s [][][]string) [][][]string {
return s
}
func Maps(m map[string]string) map[string]string {
return m
}
func CrazyMaps(m map[int][]map[string][]interop.Hash160) map[int][]map[string][]interop.Hash160 {
return m
}

View file

@ -459,8 +459,9 @@ result. This pair can then be used in Invoker `TraverseIterator` method to
retrieve actual resulting items.
Go contracts can also make use of additional type data from bindings
configuration file generated during compilation. At the moment it allows to
generate proper wrappers for simple array types, but doesn't cover structures:
configuration file generated during compilation. This can cover arrays, maps
and structures. Notice that structured types returned by methods can't be Null
at the moment (see #2795).
```
$ ./bin/neo-go contract compile -i contract.go --config contract.yml -o contract.nef --manifest manifest.json --bindings contract.bindings.yml

View file

@ -914,7 +914,7 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
c.convertMap(n)
default:
if tn, ok := t.(*types.Named); ok && isInteropPath(tn.String()) {
st, _, _ := scAndVMInteropTypeFromExpr(tn, false)
st, _, _, _ := scAndVMInteropTypeFromExpr(tn, false)
expectedLen := -1
switch st {
case smartcontract.Hash160Type:

View file

@ -287,14 +287,27 @@ func CompileAndSave(src string, o *Options) ([]byte, error) {
cfg := binding.NewConfig()
cfg.Package = di.MainPkg
for _, m := range di.Methods {
if !m.IsExported {
continue
}
for _, p := range m.Parameters {
pname := m.Name.Name + "." + p.Name
if p.RealType.TypeName != "" {
cfg.Overrides[m.Name.Name+"."+p.Name] = p.RealType
cfg.Overrides[pname] = p.RealType
}
if p.ExtendedType != nil {
cfg.Types[pname] = *p.ExtendedType
}
}
if m.ReturnTypeReal.TypeName != "" {
cfg.Overrides[m.Name.Name] = m.ReturnTypeReal
}
if m.ReturnTypeExtended != nil {
cfg.Types[m.Name.Name] = *m.ReturnTypeExtended
}
}
if len(di.NamedTypes) > 0 {
cfg.NamedTypes = di.NamedTypes
}
data, err := yaml.Marshal(&cfg)
if err != nil {

View file

@ -26,6 +26,9 @@ type DebugInfo struct {
Hash util.Uint160 `json:"hash"`
Documents []string `json:"documents"`
Methods []MethodDebugInfo `json:"methods"`
// NamedTypes are exported structured types that have some name (even
// if the original structure doesn't) and a number of internal fields.
NamedTypes map[string]binding.ExtendedType `json:"-"`
Events []EventDebugInfo `json:"events"`
// EmittedEvents contains events occurring in code.
EmittedEvents map[string][][]string `json:"-"`
@ -55,6 +58,8 @@ type MethodDebugInfo struct {
ReturnType string `json:"return"`
// ReturnTypeReal is the method's return type as specified in Go code.
ReturnTypeReal binding.Override `json:"-"`
// ReturnTypeExtended is the method's return type with additional data.
ReturnTypeExtended *binding.ExtendedType `json:"-"`
// ReturnTypeSC is a return type to use in manifest.
ReturnTypeSC smartcontract.ParamType `json:"-"`
Variables []string `json:"variables"`
@ -103,6 +108,7 @@ type DebugParam struct {
Name string `json:"name"`
Type string `json:"type"`
RealType binding.Override `json:"-"`
ExtendedType *binding.ExtendedType `json:"-"`
TypeSC smartcontract.ParamType `json:"-"`
}
@ -185,8 +191,9 @@ func (c *codegen) emitDebugInfo(contract []byte) *DebugInfo {
}
start := len(d.Methods)
d.NamedTypes = make(map[string]binding.ExtendedType)
for name, scope := range c.funcs {
m := c.methodInfoFromScope(name, scope)
m := c.methodInfoFromScope(name, scope, d.NamedTypes)
if m.Range.Start == m.Range.End {
continue
}
@ -201,7 +208,7 @@ func (c *codegen) emitDebugInfo(contract []byte) *DebugInfo {
}
func (c *codegen) registerDebugVariable(name string, expr ast.Expr) {
_, vt, _ := c.scAndVMTypeFromExpr(expr)
_, vt, _, _ := c.scAndVMTypeFromExpr(expr, nil)
if c.scope == nil {
c.staticVariables = append(c.staticVariables, name+","+vt.String())
return
@ -209,15 +216,16 @@ func (c *codegen) registerDebugVariable(name string, expr ast.Expr) {
c.scope.variables = append(c.scope.variables, name+","+vt.String())
}
func (c *codegen) methodInfoFromScope(name string, scope *funcScope) *MethodDebugInfo {
func (c *codegen) methodInfoFromScope(name string, scope *funcScope, exts map[string]binding.ExtendedType) *MethodDebugInfo {
ps := scope.decl.Type.Params
params := make([]DebugParam, 0, ps.NumFields())
for i := range ps.List {
for j := range ps.List[i].Names {
st, vt, rt := c.scAndVMTypeFromExpr(ps.List[i].Type)
st, vt, rt, et := c.scAndVMTypeFromExpr(ps.List[i].Type, exts)
params = append(params, DebugParam{
Name: ps.List[i].Names[j].Name,
Type: vt.String(),
ExtendedType: et,
RealType: rt,
TypeSC: st,
})
@ -226,7 +234,7 @@ func (c *codegen) methodInfoFromScope(name string, scope *funcScope) *MethodDebu
ss := strings.Split(name, ".")
name = ss[len(ss)-1]
r, n := utf8.DecodeRuneInString(name)
st, vt, rt := c.scAndVMReturnTypeFromScope(scope)
st, vt, rt, et := c.scAndVMReturnTypeFromScope(scope, exts)
return &MethodDebugInfo{
ID: name,
@ -239,6 +247,7 @@ func (c *codegen) methodInfoFromScope(name string, scope *funcScope) *MethodDebu
Range: scope.rng,
Parameters: params,
ReturnType: vt,
ReturnTypeExtended: et,
ReturnTypeReal: rt,
ReturnTypeSC: st,
SeqPoints: c.sequencePoints[name],
@ -246,33 +255,39 @@ func (c *codegen) methodInfoFromScope(name string, scope *funcScope) *MethodDebu
}
}
func (c *codegen) scAndVMReturnTypeFromScope(scope *funcScope) (smartcontract.ParamType, string, binding.Override) {
func (c *codegen) scAndVMReturnTypeFromScope(scope *funcScope, exts map[string]binding.ExtendedType) (smartcontract.ParamType, string, binding.Override, *binding.ExtendedType) {
results := scope.decl.Type.Results
switch results.NumFields() {
case 0:
return smartcontract.VoidType, "Void", binding.Override{}
return smartcontract.VoidType, "Void", binding.Override{}, nil
case 1:
st, vt, s := c.scAndVMTypeFromExpr(results.List[0].Type)
return st, vt.String(), s
st, vt, s, et := c.scAndVMTypeFromExpr(results.List[0].Type, exts)
return st, vt.String(), s, et
default:
// multiple return values are not supported in debugger
return smartcontract.AnyType, "Any", binding.Override{}
return smartcontract.AnyType, "Any", binding.Override{}, nil
}
}
func scAndVMInteropTypeFromExpr(named *types.Named, isPointer bool) (smartcontract.ParamType, stackitem.Type, binding.Override) {
func scAndVMInteropTypeFromExpr(named *types.Named, isPointer bool) (smartcontract.ParamType, stackitem.Type, binding.Override, *binding.ExtendedType) {
name := named.Obj().Name()
pkg := named.Obj().Pkg().Name()
switch pkg {
case "ledger", "contract":
case "ledger", "management":
switch name {
case "ParameterType", "SignerScope", "WitnessAction", "WitnessConditionType", "VMState":
return smartcontract.IntegerType, stackitem.IntegerT, binding.Override{TypeName: "int"}, nil
}
// Block, Transaction, Contract.
typeName := pkg + "." + name
et := &binding.ExtendedType{Base: smartcontract.ArrayType, Name: typeName}
if isPointer {
typeName = "*" + typeName
}
return smartcontract.ArrayType, stackitem.ArrayT, binding.Override{
Package: named.Obj().Pkg().Path(),
TypeName: typeName,
} // Block, Transaction, Contract
}, et
case "interop":
if name != "Interface" {
over := binding.Override{
@ -281,26 +296,29 @@ func scAndVMInteropTypeFromExpr(named *types.Named, isPointer bool) (smartcontra
}
switch name {
case "Hash160":
return smartcontract.Hash160Type, stackitem.ByteArrayT, over
return smartcontract.Hash160Type, stackitem.ByteArrayT, over, nil
case "Hash256":
return smartcontract.Hash256Type, stackitem.ByteArrayT, over
return smartcontract.Hash256Type, stackitem.ByteArrayT, over, nil
case "PublicKey":
return smartcontract.PublicKeyType, stackitem.ByteArrayT, over
return smartcontract.PublicKeyType, stackitem.ByteArrayT, over, nil
case "Signature":
return smartcontract.SignatureType, stackitem.ByteArrayT, over
return smartcontract.SignatureType, stackitem.ByteArrayT, over, nil
}
}
}
return smartcontract.InteropInterfaceType, stackitem.InteropT, binding.Override{TypeName: "interface{}"}
return smartcontract.InteropInterfaceType,
stackitem.InteropT,
binding.Override{TypeName: "interface{}"},
&binding.ExtendedType{Base: smartcontract.InteropInterfaceType, Interface: "iterator"} // Temporarily all interops are iterators.
}
func (c *codegen) scAndVMTypeFromExpr(typ ast.Expr) (smartcontract.ParamType, stackitem.Type, binding.Override) {
return c.scAndVMTypeFromType(c.typeOf(typ))
func (c *codegen) scAndVMTypeFromExpr(typ ast.Expr, exts map[string]binding.ExtendedType) (smartcontract.ParamType, stackitem.Type, binding.Override, *binding.ExtendedType) {
return c.scAndVMTypeFromType(c.typeOf(typ), exts)
}
func (c *codegen) scAndVMTypeFromType(t types.Type) (smartcontract.ParamType, stackitem.Type, binding.Override) {
func (c *codegen) scAndVMTypeFromType(t types.Type, exts map[string]binding.ExtendedType) (smartcontract.ParamType, stackitem.Type, binding.Override, *binding.ExtendedType) {
if t == nil {
return smartcontract.AnyType, stackitem.AnyT, binding.Override{TypeName: "interface{}"}
return smartcontract.AnyType, stackitem.AnyT, binding.Override{TypeName: "interface{}"}, nil
}
var isPtr bool
@ -314,10 +332,16 @@ func (c *codegen) scAndVMTypeFromType(t types.Type) (smartcontract.ParamType, st
}
if isNamed {
if isInteropPath(named.String()) {
return scAndVMInteropTypeFromExpr(named, isPtr)
st, vt, over, et := scAndVMInteropTypeFromExpr(named, isPtr)
if et != nil && et.Base == smartcontract.ArrayType && exts != nil && exts[et.Name].Name != et.Name {
_ = c.genStructExtended(named.Underlying().(*types.Struct), et.Name, exts)
}
return st, vt, over, et
}
}
if ptr, isPtr := t.(*types.Pointer); isPtr {
t = ptr.Elem()
}
var over binding.Override
switch t := t.Underlying().(type) {
case *types.Basic:
@ -325,43 +349,103 @@ func (c *codegen) scAndVMTypeFromType(t types.Type) (smartcontract.ParamType, st
switch {
case info&types.IsInteger != 0:
over.TypeName = "int"
return smartcontract.IntegerType, stackitem.IntegerT, over
return smartcontract.IntegerType, stackitem.IntegerT, over, nil
case info&types.IsBoolean != 0:
over.TypeName = "bool"
return smartcontract.BoolType, stackitem.BooleanT, over
return smartcontract.BoolType, stackitem.BooleanT, over, nil
case info&types.IsString != 0:
over.TypeName = "string"
return smartcontract.StringType, stackitem.ByteArrayT, over
return smartcontract.StringType, stackitem.ByteArrayT, over, nil
default:
over.TypeName = "interface{}"
return smartcontract.AnyType, stackitem.AnyT, over
return smartcontract.AnyType, stackitem.AnyT, over, nil
}
case *types.Map:
_, _, over := c.scAndVMTypeFromType(t.Elem())
et := &binding.ExtendedType{
Base: smartcontract.MapType,
}
et.Key, _, _, _ = c.scAndVMTypeFromType(t.Key(), exts)
vt, _, over, vet := c.scAndVMTypeFromType(t.Elem(), exts)
et.Value = vet
if et.Value == nil {
et.Value = &binding.ExtendedType{Base: vt}
}
over.TypeName = "map[" + t.Key().String() + "]" + over.TypeName
return smartcontract.MapType, stackitem.MapT, over
return smartcontract.MapType, stackitem.MapT, over, et
case *types.Struct:
if isNamed {
over.Package = named.Obj().Pkg().Path()
over.TypeName = named.Obj().Pkg().Name() + "." + named.Obj().Name()
_ = c.genStructExtended(t, over.TypeName, exts)
} else {
name := "unnamed"
if exts != nil {
for exts[name].Name == name {
name = name + "X"
}
return smartcontract.ArrayType, stackitem.StructT, over
_ = c.genStructExtended(t, name, exts)
}
}
return smartcontract.ArrayType, stackitem.StructT, over,
&binding.ExtendedType{ // Value-less, refer to exts.
Base: smartcontract.ArrayType,
Name: over.TypeName,
}
case *types.Slice:
if isByte(t.Elem()) {
over.TypeName = "[]byte"
return smartcontract.ByteArrayType, stackitem.ByteArrayT, over
return smartcontract.ByteArrayType, stackitem.ByteArrayT, over, nil
}
et := &binding.ExtendedType{
Base: smartcontract.ArrayType,
}
vt, _, over, vet := c.scAndVMTypeFromType(t.Elem(), exts)
et.Value = vet
if et.Value == nil {
et.Value = &binding.ExtendedType{
Base: vt,
}
}
_, _, over := c.scAndVMTypeFromType(t.Elem())
if over.TypeName != "" {
over.TypeName = "[]" + over.TypeName
}
return smartcontract.ArrayType, stackitem.ArrayT, over
return smartcontract.ArrayType, stackitem.ArrayT, over, et
default:
over.TypeName = "interface{}"
return smartcontract.AnyType, stackitem.AnyT, over
return smartcontract.AnyType, stackitem.AnyT, over, nil
}
}
func (c *codegen) genStructExtended(t *types.Struct, name string, exts map[string]binding.ExtendedType) *binding.ExtendedType {
var et *binding.ExtendedType
if exts != nil {
if exts[name].Name != name {
et = &binding.ExtendedType{
Base: smartcontract.ArrayType,
Name: name,
Fields: make([]binding.FieldExtendedType, t.NumFields()),
}
exts[name] = *et // Prefill to solve recursive structures.
for i := range et.Fields {
field := t.Field(i)
ft, _, _, fet := c.scAndVMTypeFromType(field.Type(), exts)
if fet == nil {
et.Fields[i].ExtendedType.Base = ft
} else {
et.Fields[i].ExtendedType = *fet
}
et.Fields[i].Field = field.Name()
}
exts[name] = *et // Set real structure data.
} else {
et = new(binding.ExtendedType)
*et = exts[name]
}
}
return et
}
// MarshalJSON implements the json.Marshaler interface.
func (d *DebugRange) MarshalJSON() ([]byte, error) {
return []byte(`"` + strconv.FormatUint(uint64(d.Start), 10) + `-` +

View file

@ -175,7 +175,7 @@ func (c *codegen) processNotify(f *funcScope, args []ast.Expr, hasEllipsis bool)
params := make([]string, 0, len(args[1:]))
vParams := make([]*stackitem.Type, 0, len(args[1:]))
for _, p := range args[1:] {
st, vt, _ := c.scAndVMTypeFromExpr(p)
st, vt, _, _ := c.scAndVMTypeFromExpr(p, nil)
params = append(params, st.String())
vParams = append(vParams, &vt)
}

View file

@ -53,9 +53,25 @@ type (
Hash util.Uint160 `yaml:"hash,omitempty"`
Overrides map[string]Override `yaml:"overrides,omitempty"`
CallFlags map[string]callflag.CallFlag `yaml:"callflags,omitempty"`
NamedTypes map[string]ExtendedType `yaml:"namedtypes,omitempty"`
Types map[string]ExtendedType `yaml:"types,omitempty"`
Output io.Writer `yaml:"-"`
}
ExtendedType struct {
Base smartcontract.ParamType `yaml:"base"`
Name string `yaml:"name,omitempty"` // Structure name, omitted for arrays, interfaces and maps.
Interface string `yaml:"interface,omitempty"` // Interface type name, "iterator" only for now.
Key smartcontract.ParamType `yaml:"key,omitempty"` // Key type (only simple types can be used for keys) for maps.
Value *ExtendedType `yaml:"value,omitempty"` // Value type for iterators and arrays.
Fields []FieldExtendedType `yaml:"fields,omitempty"` // Ordered type data for structure fields.
}
FieldExtendedType struct {
Field string `yaml:"field"`
ExtendedType `yaml:",inline"`
}
ContractTmpl struct {
PackageName string
ContractName string
@ -86,6 +102,8 @@ func NewConfig() Config {
return Config{
Overrides: make(map[string]Override),
CallFlags: make(map[string]callflag.CallFlag),
NamedTypes: make(map[string]ExtendedType),
Types: make(map[string]ExtendedType),
}
}
@ -101,8 +119,8 @@ func Generate(cfg Config) error {
return srcTemplate.Execute(cfg.Output, ctr)
}
func scTypeToGo(name string, typ smartcontract.ParamType, overrides map[string]Override) (string, string) {
if over, ok := overrides[name]; ok {
func scTypeToGo(name string, typ smartcontract.ParamType, cfg *Config) (string, string) {
if over, ok := cfg.Overrides[name]; ok {
return over.TypeName, over.Package
}
@ -141,7 +159,7 @@ func scTypeToGo(name string, typ smartcontract.ParamType, overrides map[string]O
// TemplateFromManifest create a contract template using the given configuration
// and type conversion function. It assumes manifest to be present in the
// configuration and assumes it to be correct (passing IsValid check).
func TemplateFromManifest(cfg Config, scTypeConverter func(string, smartcontract.ParamType, map[string]Override) (string, string)) ContractTmpl {
func TemplateFromManifest(cfg Config, scTypeConverter func(string, smartcontract.ParamType, *Config) (string, string)) ContractTmpl {
hStr := ""
for _, b := range cfg.Hash.BytesBE() {
hStr += fmt.Sprintf("\\x%02x", b)
@ -203,7 +221,7 @@ func TemplateFromManifest(cfg Config, scTypeConverter func(string, smartcontract
var varnames = make(map[string]bool)
for i := range m.Parameters {
name := m.Parameters[i].Name
typeStr, pkg := scTypeConverter(m.Name+"."+name, m.Parameters[i].Type, cfg.Overrides)
typeStr, pkg := scTypeConverter(m.Name+"."+name, m.Parameters[i].Type, &cfg)
if pkg != "" {
imports[pkg] = struct{}{}
}
@ -220,7 +238,7 @@ func TemplateFromManifest(cfg Config, scTypeConverter func(string, smartcontract
})
}
typeStr, pkg := scTypeConverter(m.Name, m.ReturnType, cfg.Overrides)
typeStr, pkg := scTypeConverter(m.Name, m.ReturnType, &cfg)
if pkg != "" {
imports[pkg] = struct{}{}
}

View file

@ -3,6 +3,7 @@ package rpcbinding
import (
"fmt"
"sort"
"strings"
"text/template"
"github.com/nspcc-dev/neo-go/pkg/smartcontract"
@ -18,14 +19,20 @@ func (c *ContractReader) {{.Name}}({{range $index, $arg := .Arguments -}}
{{- if ne $index 0}}, {{end}}
{{- .Name}} {{.Type}}
{{- end}}) {{if .ReturnType }}({{ .ReturnType }}, error) {
return unwrap.{{.CallFlag}}(c.invoker.Call(Hash, "{{ .NameABI }}"{{/* CallFlag field is used for function name */}}
{{- range $arg := .Arguments -}}, {{.Name}}{{end}}))
return {{if and (not .ItemTo) (eq .Unwrapper "Item")}}func (item stackitem.Item, err error) ({{ .ReturnType }}, error) {
if err != nil {
return nil, err
}
return {{addIndent (etTypeConverter .ExtendedReturn "item") "\t"}}
} ( {{- end -}} {{if .ItemTo -}} itemTo{{ .ItemTo }}( {{- end -}}
unwrap.{{.Unwrapper}}(c.invoker.Call(Hash, "{{ .NameABI }}"
{{- range $arg := .Arguments -}}, {{.Name}}{{end -}} )) {{- if or .ItemTo (eq .Unwrapper "Item") -}} ) {{- end}}
{{- else -}} (*result.Invoke, error) {
c.invoker.Call(Hash, "{{ .NameABI }}"
{{- range $arg := .Arguments -}}, {{.Name}}{{end}})
{{- end}}
}
{{- if eq .CallFlag "SessionIterator"}}
{{- if eq .Unwrapper "SessionIterator"}}
// {{.Name}}Expanded is similar to {{.Name}} (uses the same contract
// method), but can be useful if the server used doesn't support sessions and
@ -101,6 +108,14 @@ import (
// Hash contains contract hash.
var Hash = {{ .Hash }}
{{range $name, $typ := .NamedTypes}}
// {{toTypeName $name}} is a contract-specific {{$name}} type used by its methods.
type {{toTypeName $name}} struct {
{{- range $m := $typ.Fields}}
{{.Field}} {{etTypeToStr .ExtendedType}}
{{- end}}
}
{{end -}}
{{if .HasReader}}// Invoker is used by ContractReader to call various safe methods.
type Invoker interface {
{{if or .IsNep11D .IsNep11ND}} nep11.Invoker
@ -199,15 +214,41 @@ func New(actor Actor) *Contract {
{{end}}
{{- range $m := .Methods}}
{{template "METHOD" $m }}
{{end}}`
{{end}}
{{- range $name, $typ := .NamedTypes}}
// itemTo{{toTypeName $name}} converts stack item into *{{toTypeName $name}}.
func itemTo{{toTypeName $name}}(item stackitem.Item, err error) (*{{toTypeName $name}}, error) {
if err != nil {
return nil, err
}
arr, ok := item.Value().([]stackitem.Item)
if !ok {
return nil, errors.New("not an array")
}
if len(arr) != {{len $typ.Fields}} {
return nil, errors.New("wrong number of structure elements")
}
var srcTemplate = template.Must(template.New("generate").Parse(srcTmpl))
var res = new({{toTypeName $name}})
{{if len .Fields}} var index = -1
{{- range $m := $typ.Fields}}
index++
res.{{.Field}}, err = {{etTypeConverter .ExtendedType "arr[index]"}}
if err != nil {
return nil, fmt.Errorf("field {{.Field}}: %w", err)
}
{{end}}
{{end}}
return res, err
}
{{end}}`
type (
ContractTmpl struct {
binding.ContractTmpl
SafeMethods []binding.MethodTmpl
SafeMethods []SafeMethodTmpl
NamedTypes map[string]binding.ExtendedType
IsNep11D bool
IsNep11ND bool
@ -217,6 +258,13 @@ type (
HasWriter bool
HasIterator bool
}
SafeMethodTmpl struct {
binding.MethodTmpl
Unwrapper string
ItemTo string
ExtendedReturn binding.ExtendedType
}
)
// NewConfig initializes and returns a new config instance.
@ -268,6 +316,18 @@ func Generate(cfg binding.Config) error {
ctr.ContractTmpl = binding.TemplateFromManifest(cfg, scTypeToGo)
ctr = scTemplateToRPC(cfg, ctr, imports)
ctr.NamedTypes = cfg.NamedTypes
var srcTemplate = template.Must(template.New("generate").Funcs(template.FuncMap{
"addIndent": addIndent,
"etTypeConverter": etTypeConverter,
"etTypeToStr": func(et binding.ExtendedType) string {
r, _ := extendedTypeToGo(et, ctr.NamedTypes)
return r
},
"toTypeName": toTypeName,
"cutPointer": cutPointer,
}).Parse(srcTmpl))
return srcTemplate.Execute(cfg.Output, ctr)
}
@ -295,31 +355,8 @@ func dropStdMethods(meths []manifest.Method, std *standard.Standard) []manifest.
return meths
}
func scTypeToGo(name string, typ smartcontract.ParamType, overrides map[string]binding.Override) (string, string) {
over, ok := overrides[name]
if ok {
switch over.TypeName {
case "[]bool":
return "[]bool", ""
case "[]int", "[]uint", "[]int8", "[]uint8", "[]int16",
"[]uint16", "[]int32", "[]uint32", "[]int64", "[]uint64":
return "[]*big.Int", "math/big"
case "[][]byte":
return "[][]byte", ""
case "[]string":
return "[]string", ""
case "[]interop.Hash160":
return "[]util.Uint160", "github.com/nspcc-dev/neo-go/pkg/util"
case "[]interop.Hash256":
return "[]util.Uint256", "github.com/nspcc-dev/neo-go/pkg/util"
case "[]interop.PublicKey":
return "keys.PublicKeys", "github.com/nspcc-dev/neo-go/pkg/crypto/keys"
case "[]interop.Signature":
return "[][]byte", ""
}
}
switch typ {
func extendedTypeToGo(et binding.ExtendedType, named map[string]binding.ExtendedType) (string, string) {
switch et.Base {
case smartcontract.AnyType:
return "interface{}", ""
case smartcontract.BoolType:
@ -339,16 +376,148 @@ func scTypeToGo(name string, typ smartcontract.ParamType, overrides map[string]b
case smartcontract.SignatureType:
return "[]byte", ""
case smartcontract.ArrayType:
if len(et.Name) > 0 {
return "*" + toTypeName(et.Name), "github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
} else if et.Value != nil {
if et.Value.Base == smartcontract.PublicKeyType { // Special array wrapper.
return "keys.PublicKeys", "github.com/nspcc-dev/neo-go/pkg/crypto/keys"
}
sub, pkg := extendedTypeToGo(*et.Value, named)
return "[]" + sub, pkg
}
return "[]interface{}", ""
case smartcontract.MapType:
return "*stackitem.Map", "github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
kt, _ := extendedTypeToGo(binding.ExtendedType{Base: et.Key}, named)
vt, _ := extendedTypeToGo(*et.Value, named)
return "map[" + kt + "]" + vt, "github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
case smartcontract.InteropInterfaceType:
return "interface{}", ""
case smartcontract.VoidType:
return "", ""
default:
}
panic("unreachable")
}
func etTypeConverter(et binding.ExtendedType, v string) string {
switch et.Base {
case smartcontract.AnyType:
return v + ".Value(), nil"
case smartcontract.BoolType:
return v + ".TryBool()"
case smartcontract.IntegerType:
return v + ".TryInteger()"
case smartcontract.ByteArrayType, smartcontract.SignatureType:
return v + ".TryBytes()"
case smartcontract.StringType:
return `func (item stackitem.Item) (string, error) {
b, err := item.TryBytes()
if err != nil {
return "", err
}
if !utf8.Valid(b) {
return "", errors.New("not a UTF-8 string")
}
return string(b), nil
} (` + v + `)`
case smartcontract.Hash160Type:
return `func (item stackitem.Item) (util.Uint160, error) {
b, err := item.TryBytes()
if err != nil {
return util.Uint160{}, err
}
u, err := util.Uint160DecodeBytesBE(b)
if err != nil {
return util.Uint160{}, err
}
return u, nil
} (` + v + `)`
case smartcontract.Hash256Type:
return `func (item stackitem.Item) (util.Uint256, error) {
b, err := item.TryBytes()
if err != nil {
return util.Uint256{}, err
}
u, err := util.Uint256DecodeBytesBE(b)
if err != nil {
return util.Uint256{}, err
}
return u, nil
} (` + v + `)`
case smartcontract.PublicKeyType:
return `func (item stackitem.Item) (*keys.PublicKey, error) {
b, err := item.TryBytes()
if err != nil {
return nil, err
}
k, err := keys.NewPublicKeyFromBytes(b, elliptic.P256())
if err != nil {
return nil, err
}
return k, nil
} (` + v + `)`
case smartcontract.ArrayType:
if len(et.Name) > 0 {
return "itemTo" + toTypeName(et.Name) + "(" + v + ", nil)"
} else if et.Value != nil {
at, _ := extendedTypeToGo(et, nil)
return `func (item stackitem.Item) (` + at + `, error) {
arr, ok := item.Value().([]stackitem.Item)
if !ok {
return nil, errors.New("not an array")
}
res := make(` + at + `, len(arr))
for i := range res {
res[i], err = ` + addIndent(etTypeConverter(*et.Value, "arr[i]"), "\t\t") + `
if err != nil {
return nil, fmt.Errorf("item %d: %w", i, err)
}
}
return res, nil
} (` + v + `)`
}
return etTypeConverter(binding.ExtendedType{
Base: smartcontract.ArrayType,
Value: &binding.ExtendedType{
Base: smartcontract.AnyType,
},
}, v)
case smartcontract.MapType:
at, _ := extendedTypeToGo(et, nil)
return `func (item stackitem.Item) (` + at + `, error) {
m, ok := item.Value().([]stackitem.MapElement)
if !ok {
return nil, fmt.Errorf("%s is not a map", item.Type().String())
}
res := make(` + at + `)
for i := range m {
k, err := ` + addIndent(etTypeConverter(binding.ExtendedType{Base: et.Key}, "m[i].Key"), "\t\t") + `
if err != nil {
return nil, fmt.Errorf("key %d: %w", i, err)
}
v, err := ` + addIndent(etTypeConverter(*et.Value, "m[i].Value"), "\t\t") + `
if err != nil {
return nil, fmt.Errorf("value %d: %w", i, err)
}
res[k] = v
}
return res, nil
} (` + v + `)`
case smartcontract.InteropInterfaceType:
return "item.Value(), nil"
case smartcontract.VoidType:
return ""
}
panic("unreachable")
}
func scTypeToGo(name string, typ smartcontract.ParamType, cfg *binding.Config) (string, string) {
et, ok := cfg.Types[name]
if !ok {
et = binding.ExtendedType{Base: typ}
}
return extendedTypeToGo(et, cfg.NamedTypes)
}
func scTemplateToRPC(cfg binding.Config, ctr ContractTmpl, imports map[string]struct{}) ContractTmpl {
@ -359,7 +528,14 @@ func scTemplateToRPC(cfg binding.Config, ctr ContractTmpl, imports map[string]st
for i := 0; i < len(ctr.Methods); i++ {
abim := cfg.Manifest.ABI.GetMethod(ctr.Methods[i].NameABI, len(ctr.Methods[i].Arguments))
if abim.Safe {
ctr.SafeMethods = append(ctr.SafeMethods, ctr.Methods[i])
ctr.SafeMethods = append(ctr.SafeMethods, SafeMethodTmpl{MethodTmpl: ctr.Methods[i]})
et, ok := cfg.Types[abim.Name]
if ok {
ctr.SafeMethods[len(ctr.SafeMethods)-1].ExtendedReturn = et
if abim.ReturnType == smartcontract.ArrayType && len(et.Name) > 0 {
ctr.SafeMethods[len(ctr.SafeMethods)-1].ItemTo = cutPointer(ctr.Methods[i].ReturnType)
}
}
ctr.Methods = append(ctr.Methods[:i], ctr.Methods[i+1:]...)
i--
} else {
@ -369,7 +545,13 @@ func scTemplateToRPC(cfg binding.Config, ctr ContractTmpl, imports map[string]st
}
}
}
// We're misusing CallFlag field for function name here.
for _, et := range cfg.NamedTypes {
addETImports(et, ctr.NamedTypes, imports)
}
if len(cfg.NamedTypes) > 0 {
imports["errors"] = struct{}{}
}
for i := range ctr.SafeMethods {
switch ctr.SafeMethods[i].ReturnType {
case "interface{}":
@ -379,47 +561,50 @@ func scTemplateToRPC(cfg binding.Config, ctr ContractTmpl, imports map[string]st
imports["github.com/nspcc-dev/neo-go/pkg/vm/stackitem"] = struct{}{}
imports["github.com/nspcc-dev/neo-go/pkg/neorpc/result"] = struct{}{}
ctr.SafeMethods[i].ReturnType = "uuid.UUID, result.Iterator"
ctr.SafeMethods[i].CallFlag = "SessionIterator"
ctr.SafeMethods[i].Unwrapper = "SessionIterator"
ctr.HasIterator = true
} else {
imports["github.com/nspcc-dev/neo-go/pkg/vm/stackitem"] = struct{}{}
ctr.SafeMethods[i].ReturnType = "stackitem.Item"
ctr.SafeMethods[i].CallFlag = "Item"
ctr.SafeMethods[i].Unwrapper = "Item"
}
case "bool":
ctr.SafeMethods[i].CallFlag = "Bool"
ctr.SafeMethods[i].Unwrapper = "Bool"
case "*big.Int":
ctr.SafeMethods[i].CallFlag = "BigInt"
ctr.SafeMethods[i].Unwrapper = "BigInt"
case "string":
ctr.SafeMethods[i].CallFlag = "UTF8String"
ctr.SafeMethods[i].Unwrapper = "UTF8String"
case "util.Uint160":
ctr.SafeMethods[i].CallFlag = "Uint160"
ctr.SafeMethods[i].Unwrapper = "Uint160"
case "util.Uint256":
ctr.SafeMethods[i].CallFlag = "Uint256"
ctr.SafeMethods[i].Unwrapper = "Uint256"
case "*keys.PublicKey":
ctr.SafeMethods[i].CallFlag = "PublicKey"
ctr.SafeMethods[i].Unwrapper = "PublicKey"
case "[]byte":
ctr.SafeMethods[i].CallFlag = "Bytes"
ctr.SafeMethods[i].Unwrapper = "Bytes"
case "[]interface{}":
imports["github.com/nspcc-dev/neo-go/pkg/vm/stackitem"] = struct{}{}
ctr.SafeMethods[i].ReturnType = "[]stackitem.Item"
ctr.SafeMethods[i].CallFlag = "Array"
ctr.SafeMethods[i].Unwrapper = "Array"
case "*stackitem.Map":
ctr.SafeMethods[i].CallFlag = "Map"
ctr.SafeMethods[i].Unwrapper = "Map"
case "[]bool":
ctr.SafeMethods[i].CallFlag = "ArrayOfBools"
ctr.SafeMethods[i].Unwrapper = "ArrayOfBools"
case "[]*big.Int":
ctr.SafeMethods[i].CallFlag = "ArrayOfBigInts"
ctr.SafeMethods[i].Unwrapper = "ArrayOfBigInts"
case "[][]byte":
ctr.SafeMethods[i].CallFlag = "ArrayOfBytes"
ctr.SafeMethods[i].Unwrapper = "ArrayOfBytes"
case "[]string":
ctr.SafeMethods[i].CallFlag = "ArrayOfUTF8Strings"
ctr.SafeMethods[i].Unwrapper = "ArrayOfUTF8Strings"
case "[]util.Uint160":
ctr.SafeMethods[i].CallFlag = "ArrayOfUint160"
ctr.SafeMethods[i].Unwrapper = "ArrayOfUint160"
case "[]util.Uint256":
ctr.SafeMethods[i].CallFlag = "ArrayOfUint256"
ctr.SafeMethods[i].Unwrapper = "ArrayOfUint256"
case "keys.PublicKeys":
ctr.SafeMethods[i].CallFlag = "ArrayOfPublicKeys"
ctr.SafeMethods[i].Unwrapper = "ArrayOfPublicKeys"
default:
addETImports(ctr.SafeMethods[i].ExtendedReturn, ctr.NamedTypes, imports)
ctr.SafeMethods[i].Unwrapper = "Item"
}
}
@ -446,3 +631,52 @@ func scTemplateToRPC(cfg binding.Config, ctr ContractTmpl, imports map[string]st
sort.Strings(ctr.Imports)
return ctr
}
func addETImports(et binding.ExtendedType, named map[string]binding.ExtendedType, imports map[string]struct{}) {
_, pkg := extendedTypeToGo(et, named)
if pkg != "" {
imports[pkg] = struct{}{}
}
// Additional packages used during decoding.
switch et.Base {
case smartcontract.StringType:
imports["unicode/utf8"] = struct{}{}
imports["errors"] = struct{}{}
case smartcontract.PublicKeyType:
imports["crypto/elliptic"] = struct{}{}
case smartcontract.MapType:
imports["fmt"] = struct{}{}
case smartcontract.ArrayType:
imports["errors"] = struct{}{}
imports["fmt"] = struct{}{}
}
if et.Value != nil {
addETImports(*et.Value, named, imports)
}
if et.Base == smartcontract.MapType {
addETImports(binding.ExtendedType{Base: et.Key}, named, imports)
}
for i := range et.Fields {
addETImports(et.Fields[i].ExtendedType, named, imports)
}
}
func cutPointer(s string) string {
if s[0] == '*' {
return s[1:]
}
return s
}
func toTypeName(s string) string {
return strings.Map(func(c rune) rune {
if c == '.' {
return -1
}
return c
}, strings.ToUpper(s[0:1])+s[1:])
}
func addIndent(str string, ind string) string {
return strings.ReplaceAll(str, "\n", "\n"+ind)
}