Merge pull request #2828 from nspcc-dev/rpcwrapper-structures

Handle structures in the RPC wrapper generator
This commit is contained in:
Roman Khimov 2022-12-06 21:40:16 +07:00 committed by GitHub
commit cad0fab704
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 2220 additions and 126 deletions

View file

@ -401,6 +401,7 @@ func TestAssistedRPCBindings(t *testing.T) {
} }
checkBinding(filepath.Join("testdata", "types")) checkBinding(filepath.Join("testdata", "types"))
checkBinding(filepath.Join("testdata", "structs"))
} }
func TestGenerate_Errors(t *testing.T) { func TestGenerate_Errors(t *testing.T) {

View file

@ -0,0 +1,3 @@
name: "Types"
sourceurl: https://github.com/nspcc-dev/neo-go/
safemethods: ["contract", "block", "transaction", "struct"]

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,39 @@
package structs
import (
"github.com/nspcc-dev/neo-go/pkg/interop"
"github.com/nspcc-dev/neo-go/pkg/interop/native/ledger"
"github.com/nspcc-dev/neo-go/pkg/interop/native/management"
)
type Internal struct {
Bool bool
Int int
Bytes []byte
String string
H160 interop.Hash160
H256 interop.Hash256
PK interop.PublicKey
PubKey interop.PublicKey
Sign interop.Signature
ArrOfBytes [][]byte
ArrOfH160 []interop.Hash160
Map map[int][]interop.PublicKey
Struct *Internal
}
func Contract(mc management.Contract) management.Contract {
return mc
}
func Block(b *ledger.Block) *ledger.Block {
return b
}
func Transaction(t *ledger.Transaction) *ledger.Transaction {
return t
}
func Struct(s *Internal) *Internal {
return s
}

View file

@ -1,3 +1,3 @@
name: "Types" name: "Types"
sourceurl: https://github.com/nspcc-dev/neo-go/ sourceurl: https://github.com/nspcc-dev/neo-go/
safemethods: ["bool", "int", "bytes", "string", "hash160", "hash256", "publicKey", "signature", "bools", "ints", "bytess", "strings", "hash160s", "hash256s", "publicKeys", "signatures"] safemethods: ["bool", "int", "bytes", "string", "hash160", "hash256", "publicKey", "signature", "bools", "ints", "bytess", "strings", "hash160s", "hash256s", "publicKeys", "signatures", "aAAStrings", "maps", "crazyMaps"]

View file

@ -2,11 +2,15 @@
package types package types
import ( import (
"errors"
"fmt"
"github.com/nspcc-dev/neo-go/pkg/crypto/keys" "github.com/nspcc-dev/neo-go/pkg/crypto/keys"
"github.com/nspcc-dev/neo-go/pkg/neorpc/result" "github.com/nspcc-dev/neo-go/pkg/neorpc/result"
"github.com/nspcc-dev/neo-go/pkg/rpcclient/unwrap" "github.com/nspcc-dev/neo-go/pkg/rpcclient/unwrap"
"github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/util"
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
"math/big" "math/big"
"unicode/utf8"
) )
// Hash contains contract hash. // Hash contains contract hash.
@ -28,6 +32,64 @@ func NewReader(invoker Invoker) *ContractReader {
} }
// AAAStrings invokes `aAAStrings` method of contract.
func (c *ContractReader) AAAStrings(s [][][]string) ([][][]string, error) {
return func (item stackitem.Item, err error) ([][][]string, error) {
if err != nil {
return nil, err
}
return func (item stackitem.Item) ([][][]string, error) {
arr, ok := item.Value().([]stackitem.Item)
if !ok {
return nil, errors.New("not an array")
}
res := make([][][]string, len(arr))
for i := range res {
res[i], err = func (item stackitem.Item) ([][]string, error) {
arr, ok := item.Value().([]stackitem.Item)
if !ok {
return nil, errors.New("not an array")
}
res := make([][]string, len(arr))
for i := range res {
res[i], err = func (item stackitem.Item) ([]string, error) {
arr, ok := item.Value().([]stackitem.Item)
if !ok {
return nil, errors.New("not an array")
}
res := make([]string, len(arr))
for i := range res {
res[i], err = func (item stackitem.Item) (string, error) {
b, err := item.TryBytes()
if err != nil {
return "", err
}
if !utf8.Valid(b) {
return "", errors.New("not a UTF-8 string")
}
return string(b), nil
} (arr[i])
if err != nil {
return nil, fmt.Errorf("item %d: %w", i, err)
}
}
return res, nil
} (arr[i])
if err != nil {
return nil, fmt.Errorf("item %d: %w", i, err)
}
}
return res, nil
} (arr[i])
if err != nil {
return nil, fmt.Errorf("item %d: %w", i, err)
}
}
return res, nil
} (item)
} (unwrap.Item(c.invoker.Call(Hash, "aAAStrings", s)))
}
// Bool invokes `bool` method of contract. // Bool invokes `bool` method of contract.
func (c *ContractReader) Bool(b bool) (bool, error) { func (c *ContractReader) Bool(b bool) (bool, error) {
return unwrap.Bool(c.invoker.Call(Hash, "bool", b)) return unwrap.Bool(c.invoker.Call(Hash, "bool", b))
@ -48,6 +110,97 @@ func (c *ContractReader) Bytess(b [][]byte) ([][]byte, error) {
return unwrap.ArrayOfBytes(c.invoker.Call(Hash, "bytess", b)) return unwrap.ArrayOfBytes(c.invoker.Call(Hash, "bytess", b))
} }
// CrazyMaps invokes `crazyMaps` method of contract.
func (c *ContractReader) CrazyMaps(m map[*big.Int][]map[string][]util.Uint160) (map[*big.Int][]map[string][]util.Uint160, error) {
return func (item stackitem.Item, err error) (map[*big.Int][]map[string][]util.Uint160, error) {
if err != nil {
return nil, err
}
return func (item stackitem.Item) (map[*big.Int][]map[string][]util.Uint160, error) {
m, ok := item.Value().([]stackitem.MapElement)
if !ok {
return nil, fmt.Errorf("%s is not a map", item.Type().String())
}
res := make(map[*big.Int][]map[string][]util.Uint160)
for i := range m {
k, err := m[i].Key.TryInteger()
if err != nil {
return nil, fmt.Errorf("key %d: %w", i, err)
}
v, err := func (item stackitem.Item) ([]map[string][]util.Uint160, error) {
arr, ok := item.Value().([]stackitem.Item)
if !ok {
return nil, errors.New("not an array")
}
res := make([]map[string][]util.Uint160, len(arr))
for i := range res {
res[i], err = func (item stackitem.Item) (map[string][]util.Uint160, error) {
m, ok := item.Value().([]stackitem.MapElement)
if !ok {
return nil, fmt.Errorf("%s is not a map", item.Type().String())
}
res := make(map[string][]util.Uint160)
for i := range m {
k, err := func (item stackitem.Item) (string, error) {
b, err := item.TryBytes()
if err != nil {
return "", err
}
if !utf8.Valid(b) {
return "", errors.New("not a UTF-8 string")
}
return string(b), nil
} (m[i].Key)
if err != nil {
return nil, fmt.Errorf("key %d: %w", i, err)
}
v, err := func (item stackitem.Item) ([]util.Uint160, error) {
arr, ok := item.Value().([]stackitem.Item)
if !ok {
return nil, errors.New("not an array")
}
res := make([]util.Uint160, len(arr))
for i := range res {
res[i], err = func (item stackitem.Item) (util.Uint160, error) {
b, err := item.TryBytes()
if err != nil {
return util.Uint160{}, err
}
u, err := util.Uint160DecodeBytesBE(b)
if err != nil {
return util.Uint160{}, err
}
return u, nil
} (arr[i])
if err != nil {
return nil, fmt.Errorf("item %d: %w", i, err)
}
}
return res, nil
} (m[i].Value)
if err != nil {
return nil, fmt.Errorf("value %d: %w", i, err)
}
res[k] = v
}
return res, nil
} (arr[i])
if err != nil {
return nil, fmt.Errorf("item %d: %w", i, err)
}
}
return res, nil
} (m[i].Value)
if err != nil {
return nil, fmt.Errorf("value %d: %w", i, err)
}
res[k] = v
}
return res, nil
} (item)
} (unwrap.Item(c.invoker.Call(Hash, "crazyMaps", m)))
}
// Hash160 invokes `hash160` method of contract. // Hash160 invokes `hash160` method of contract.
func (c *ContractReader) Hash160(h util.Uint160) (util.Uint160, error) { func (c *ContractReader) Hash160(h util.Uint160) (util.Uint160, error) {
return unwrap.Uint160(c.invoker.Call(Hash, "hash160", h)) return unwrap.Uint160(c.invoker.Call(Hash, "hash160", h))
@ -78,6 +231,52 @@ func (c *ContractReader) Ints(i []*big.Int) ([]*big.Int, error) {
return unwrap.ArrayOfBigInts(c.invoker.Call(Hash, "ints", i)) return unwrap.ArrayOfBigInts(c.invoker.Call(Hash, "ints", i))
} }
// Maps invokes `maps` method of contract.
func (c *ContractReader) Maps(m map[string]string) (map[string]string, error) {
return func (item stackitem.Item, err error) (map[string]string, error) {
if err != nil {
return nil, err
}
return func (item stackitem.Item) (map[string]string, error) {
m, ok := item.Value().([]stackitem.MapElement)
if !ok {
return nil, fmt.Errorf("%s is not a map", item.Type().String())
}
res := make(map[string]string)
for i := range m {
k, err := func (item stackitem.Item) (string, error) {
b, err := item.TryBytes()
if err != nil {
return "", err
}
if !utf8.Valid(b) {
return "", errors.New("not a UTF-8 string")
}
return string(b), nil
} (m[i].Key)
if err != nil {
return nil, fmt.Errorf("key %d: %w", i, err)
}
v, err := func (item stackitem.Item) (string, error) {
b, err := item.TryBytes()
if err != nil {
return "", err
}
if !utf8.Valid(b) {
return "", errors.New("not a UTF-8 string")
}
return string(b), nil
} (m[i].Value)
if err != nil {
return nil, fmt.Errorf("value %d: %w", i, err)
}
res[k] = v
}
return res, nil
} (item)
} (unwrap.Item(c.invoker.Call(Hash, "maps", m)))
}
// PublicKey invokes `publicKey` method of contract. // PublicKey invokes `publicKey` method of contract.
func (c *ContractReader) PublicKey(k *keys.PublicKey) (*keys.PublicKey, error) { func (c *ContractReader) PublicKey(k *keys.PublicKey) (*keys.PublicKey, error) {
return unwrap.PublicKey(c.invoker.Call(Hash, "publicKey", k)) return unwrap.PublicKey(c.invoker.Call(Hash, "publicKey", k))

View file

@ -67,3 +67,15 @@ func PublicKeys(k []interop.PublicKey) []interop.PublicKey {
func Signatures(s []interop.Signature) []interop.Signature { func Signatures(s []interop.Signature) []interop.Signature {
return nil return nil
} }
func AAAStrings(s [][][]string) [][][]string {
return s
}
func Maps(m map[string]string) map[string]string {
return m
}
func CrazyMaps(m map[int][]map[string][]interop.Hash160) map[int][]map[string][]interop.Hash160 {
return m
}

View file

@ -459,8 +459,9 @@ result. This pair can then be used in Invoker `TraverseIterator` method to
retrieve actual resulting items. retrieve actual resulting items.
Go contracts can also make use of additional type data from bindings Go contracts can also make use of additional type data from bindings
configuration file generated during compilation. At the moment it allows to configuration file generated during compilation. This can cover arrays, maps
generate proper wrappers for simple array types, but doesn't cover structures: and structures. Notice that structured types returned by methods can't be Null
at the moment (see #2795).
``` ```
$ ./bin/neo-go contract compile -i contract.go --config contract.yml -o contract.nef --manifest manifest.json --bindings contract.bindings.yml $ ./bin/neo-go contract compile -i contract.go --config contract.yml -o contract.nef --manifest manifest.json --bindings contract.bindings.yml

View file

@ -914,7 +914,7 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
c.convertMap(n) c.convertMap(n)
default: default:
if tn, ok := t.(*types.Named); ok && isInteropPath(tn.String()) { if tn, ok := t.(*types.Named); ok && isInteropPath(tn.String()) {
st, _, _ := scAndVMInteropTypeFromExpr(tn, false) st, _, _, _ := scAndVMInteropTypeFromExpr(tn, false)
expectedLen := -1 expectedLen := -1
switch st { switch st {
case smartcontract.Hash160Type: case smartcontract.Hash160Type:

View file

@ -287,14 +287,27 @@ func CompileAndSave(src string, o *Options) ([]byte, error) {
cfg := binding.NewConfig() cfg := binding.NewConfig()
cfg.Package = di.MainPkg cfg.Package = di.MainPkg
for _, m := range di.Methods { for _, m := range di.Methods {
if !m.IsExported {
continue
}
for _, p := range m.Parameters { for _, p := range m.Parameters {
pname := m.Name.Name + "." + p.Name
if p.RealType.TypeName != "" { if p.RealType.TypeName != "" {
cfg.Overrides[m.Name.Name+"."+p.Name] = p.RealType cfg.Overrides[pname] = p.RealType
}
if p.ExtendedType != nil {
cfg.Types[pname] = *p.ExtendedType
} }
} }
if m.ReturnTypeReal.TypeName != "" { if m.ReturnTypeReal.TypeName != "" {
cfg.Overrides[m.Name.Name] = m.ReturnTypeReal cfg.Overrides[m.Name.Name] = m.ReturnTypeReal
} }
if m.ReturnTypeExtended != nil {
cfg.Types[m.Name.Name] = *m.ReturnTypeExtended
}
}
if len(di.NamedTypes) > 0 {
cfg.NamedTypes = di.NamedTypes
} }
data, err := yaml.Marshal(&cfg) data, err := yaml.Marshal(&cfg)
if err != nil { if err != nil {

View file

@ -26,7 +26,10 @@ type DebugInfo struct {
Hash util.Uint160 `json:"hash"` Hash util.Uint160 `json:"hash"`
Documents []string `json:"documents"` Documents []string `json:"documents"`
Methods []MethodDebugInfo `json:"methods"` Methods []MethodDebugInfo `json:"methods"`
Events []EventDebugInfo `json:"events"` // NamedTypes are exported structured types that have some name (even
// if the original structure doesn't) and a number of internal fields.
NamedTypes map[string]binding.ExtendedType `json:"-"`
Events []EventDebugInfo `json:"events"`
// EmittedEvents contains events occurring in code. // EmittedEvents contains events occurring in code.
EmittedEvents map[string][][]string `json:"-"` EmittedEvents map[string][][]string `json:"-"`
// InvokedContracts contains foreign contract invocations. // InvokedContracts contains foreign contract invocations.
@ -55,6 +58,8 @@ type MethodDebugInfo struct {
ReturnType string `json:"return"` ReturnType string `json:"return"`
// ReturnTypeReal is the method's return type as specified in Go code. // ReturnTypeReal is the method's return type as specified in Go code.
ReturnTypeReal binding.Override `json:"-"` ReturnTypeReal binding.Override `json:"-"`
// ReturnTypeExtended is the method's return type with additional data.
ReturnTypeExtended *binding.ExtendedType `json:"-"`
// ReturnTypeSC is a return type to use in manifest. // ReturnTypeSC is a return type to use in manifest.
ReturnTypeSC smartcontract.ParamType `json:"-"` ReturnTypeSC smartcontract.ParamType `json:"-"`
Variables []string `json:"variables"` Variables []string `json:"variables"`
@ -100,10 +105,11 @@ type DebugRange struct {
// DebugParam represents the variables's name and type. // DebugParam represents the variables's name and type.
type DebugParam struct { type DebugParam struct {
Name string `json:"name"` Name string `json:"name"`
Type string `json:"type"` Type string `json:"type"`
RealType binding.Override `json:"-"` RealType binding.Override `json:"-"`
TypeSC smartcontract.ParamType `json:"-"` ExtendedType *binding.ExtendedType `json:"-"`
TypeSC smartcontract.ParamType `json:"-"`
} }
func (c *codegen) saveSequencePoint(n ast.Node) { func (c *codegen) saveSequencePoint(n ast.Node) {
@ -185,8 +191,9 @@ func (c *codegen) emitDebugInfo(contract []byte) *DebugInfo {
} }
start := len(d.Methods) start := len(d.Methods)
d.NamedTypes = make(map[string]binding.ExtendedType)
for name, scope := range c.funcs { for name, scope := range c.funcs {
m := c.methodInfoFromScope(name, scope) m := c.methodInfoFromScope(name, scope, d.NamedTypes)
if m.Range.Start == m.Range.End { if m.Range.Start == m.Range.End {
continue continue
} }
@ -201,7 +208,7 @@ func (c *codegen) emitDebugInfo(contract []byte) *DebugInfo {
} }
func (c *codegen) registerDebugVariable(name string, expr ast.Expr) { func (c *codegen) registerDebugVariable(name string, expr ast.Expr) {
_, vt, _ := c.scAndVMTypeFromExpr(expr) _, vt, _, _ := c.scAndVMTypeFromExpr(expr, nil)
if c.scope == nil { if c.scope == nil {
c.staticVariables = append(c.staticVariables, name+","+vt.String()) c.staticVariables = append(c.staticVariables, name+","+vt.String())
return return
@ -209,24 +216,25 @@ func (c *codegen) registerDebugVariable(name string, expr ast.Expr) {
c.scope.variables = append(c.scope.variables, name+","+vt.String()) c.scope.variables = append(c.scope.variables, name+","+vt.String())
} }
func (c *codegen) methodInfoFromScope(name string, scope *funcScope) *MethodDebugInfo { func (c *codegen) methodInfoFromScope(name string, scope *funcScope, exts map[string]binding.ExtendedType) *MethodDebugInfo {
ps := scope.decl.Type.Params ps := scope.decl.Type.Params
params := make([]DebugParam, 0, ps.NumFields()) params := make([]DebugParam, 0, ps.NumFields())
for i := range ps.List { for i := range ps.List {
for j := range ps.List[i].Names { for j := range ps.List[i].Names {
st, vt, rt := c.scAndVMTypeFromExpr(ps.List[i].Type) st, vt, rt, et := c.scAndVMTypeFromExpr(ps.List[i].Type, exts)
params = append(params, DebugParam{ params = append(params, DebugParam{
Name: ps.List[i].Names[j].Name, Name: ps.List[i].Names[j].Name,
Type: vt.String(), Type: vt.String(),
RealType: rt, ExtendedType: et,
TypeSC: st, RealType: rt,
TypeSC: st,
}) })
} }
} }
ss := strings.Split(name, ".") ss := strings.Split(name, ".")
name = ss[len(ss)-1] name = ss[len(ss)-1]
r, n := utf8.DecodeRuneInString(name) r, n := utf8.DecodeRuneInString(name)
st, vt, rt := c.scAndVMReturnTypeFromScope(scope) st, vt, rt, et := c.scAndVMReturnTypeFromScope(scope, exts)
return &MethodDebugInfo{ return &MethodDebugInfo{
ID: name, ID: name,
@ -234,45 +242,52 @@ func (c *codegen) methodInfoFromScope(name string, scope *funcScope) *MethodDebu
Name: string(unicode.ToLower(r)) + name[n:], Name: string(unicode.ToLower(r)) + name[n:],
Namespace: scope.pkg.Name(), Namespace: scope.pkg.Name(),
}, },
IsExported: scope.decl.Name.IsExported(), IsExported: scope.decl.Name.IsExported(),
IsFunction: scope.decl.Recv == nil, IsFunction: scope.decl.Recv == nil,
Range: scope.rng, Range: scope.rng,
Parameters: params, Parameters: params,
ReturnType: vt, ReturnType: vt,
ReturnTypeReal: rt, ReturnTypeExtended: et,
ReturnTypeSC: st, ReturnTypeReal: rt,
SeqPoints: c.sequencePoints[name], ReturnTypeSC: st,
Variables: scope.variables, SeqPoints: c.sequencePoints[name],
Variables: scope.variables,
} }
} }
func (c *codegen) scAndVMReturnTypeFromScope(scope *funcScope) (smartcontract.ParamType, string, binding.Override) { func (c *codegen) scAndVMReturnTypeFromScope(scope *funcScope, exts map[string]binding.ExtendedType) (smartcontract.ParamType, string, binding.Override, *binding.ExtendedType) {
results := scope.decl.Type.Results results := scope.decl.Type.Results
switch results.NumFields() { switch results.NumFields() {
case 0: case 0:
return smartcontract.VoidType, "Void", binding.Override{} return smartcontract.VoidType, "Void", binding.Override{}, nil
case 1: case 1:
st, vt, s := c.scAndVMTypeFromExpr(results.List[0].Type) st, vt, s, et := c.scAndVMTypeFromExpr(results.List[0].Type, exts)
return st, vt.String(), s return st, vt.String(), s, et
default: default:
// multiple return values are not supported in debugger // multiple return values are not supported in debugger
return smartcontract.AnyType, "Any", binding.Override{} return smartcontract.AnyType, "Any", binding.Override{}, nil
} }
} }
func scAndVMInteropTypeFromExpr(named *types.Named, isPointer bool) (smartcontract.ParamType, stackitem.Type, binding.Override) { func scAndVMInteropTypeFromExpr(named *types.Named, isPointer bool) (smartcontract.ParamType, stackitem.Type, binding.Override, *binding.ExtendedType) {
name := named.Obj().Name() name := named.Obj().Name()
pkg := named.Obj().Pkg().Name() pkg := named.Obj().Pkg().Name()
switch pkg { switch pkg {
case "ledger", "contract": case "ledger", "management":
switch name {
case "ParameterType", "SignerScope", "WitnessAction", "WitnessConditionType", "VMState":
return smartcontract.IntegerType, stackitem.IntegerT, binding.Override{TypeName: "int"}, nil
}
// Block, Transaction, Contract.
typeName := pkg + "." + name typeName := pkg + "." + name
et := &binding.ExtendedType{Base: smartcontract.ArrayType, Name: typeName}
if isPointer { if isPointer {
typeName = "*" + typeName typeName = "*" + typeName
} }
return smartcontract.ArrayType, stackitem.ArrayT, binding.Override{ return smartcontract.ArrayType, stackitem.ArrayT, binding.Override{
Package: named.Obj().Pkg().Path(), Package: named.Obj().Pkg().Path(),
TypeName: typeName, TypeName: typeName,
} // Block, Transaction, Contract }, et
case "interop": case "interop":
if name != "Interface" { if name != "Interface" {
over := binding.Override{ over := binding.Override{
@ -281,26 +296,29 @@ func scAndVMInteropTypeFromExpr(named *types.Named, isPointer bool) (smartcontra
} }
switch name { switch name {
case "Hash160": case "Hash160":
return smartcontract.Hash160Type, stackitem.ByteArrayT, over return smartcontract.Hash160Type, stackitem.ByteArrayT, over, nil
case "Hash256": case "Hash256":
return smartcontract.Hash256Type, stackitem.ByteArrayT, over return smartcontract.Hash256Type, stackitem.ByteArrayT, over, nil
case "PublicKey": case "PublicKey":
return smartcontract.PublicKeyType, stackitem.ByteArrayT, over return smartcontract.PublicKeyType, stackitem.ByteArrayT, over, nil
case "Signature": case "Signature":
return smartcontract.SignatureType, stackitem.ByteArrayT, over return smartcontract.SignatureType, stackitem.ByteArrayT, over, nil
} }
} }
} }
return smartcontract.InteropInterfaceType, stackitem.InteropT, binding.Override{TypeName: "interface{}"} return smartcontract.InteropInterfaceType,
stackitem.InteropT,
binding.Override{TypeName: "interface{}"},
&binding.ExtendedType{Base: smartcontract.InteropInterfaceType, Interface: "iterator"} // Temporarily all interops are iterators.
} }
func (c *codegen) scAndVMTypeFromExpr(typ ast.Expr) (smartcontract.ParamType, stackitem.Type, binding.Override) { func (c *codegen) scAndVMTypeFromExpr(typ ast.Expr, exts map[string]binding.ExtendedType) (smartcontract.ParamType, stackitem.Type, binding.Override, *binding.ExtendedType) {
return c.scAndVMTypeFromType(c.typeOf(typ)) return c.scAndVMTypeFromType(c.typeOf(typ), exts)
} }
func (c *codegen) scAndVMTypeFromType(t types.Type) (smartcontract.ParamType, stackitem.Type, binding.Override) { func (c *codegen) scAndVMTypeFromType(t types.Type, exts map[string]binding.ExtendedType) (smartcontract.ParamType, stackitem.Type, binding.Override, *binding.ExtendedType) {
if t == nil { if t == nil {
return smartcontract.AnyType, stackitem.AnyT, binding.Override{TypeName: "interface{}"} return smartcontract.AnyType, stackitem.AnyT, binding.Override{TypeName: "interface{}"}, nil
} }
var isPtr bool var isPtr bool
@ -314,10 +332,16 @@ func (c *codegen) scAndVMTypeFromType(t types.Type) (smartcontract.ParamType, st
} }
if isNamed { if isNamed {
if isInteropPath(named.String()) { if isInteropPath(named.String()) {
return scAndVMInteropTypeFromExpr(named, isPtr) st, vt, over, et := scAndVMInteropTypeFromExpr(named, isPtr)
if et != nil && et.Base == smartcontract.ArrayType && exts != nil && exts[et.Name].Name != et.Name {
_ = c.genStructExtended(named.Underlying().(*types.Struct), et.Name, exts)
}
return st, vt, over, et
} }
} }
if ptr, isPtr := t.(*types.Pointer); isPtr {
t = ptr.Elem()
}
var over binding.Override var over binding.Override
switch t := t.Underlying().(type) { switch t := t.Underlying().(type) {
case *types.Basic: case *types.Basic:
@ -325,43 +349,103 @@ func (c *codegen) scAndVMTypeFromType(t types.Type) (smartcontract.ParamType, st
switch { switch {
case info&types.IsInteger != 0: case info&types.IsInteger != 0:
over.TypeName = "int" over.TypeName = "int"
return smartcontract.IntegerType, stackitem.IntegerT, over return smartcontract.IntegerType, stackitem.IntegerT, over, nil
case info&types.IsBoolean != 0: case info&types.IsBoolean != 0:
over.TypeName = "bool" over.TypeName = "bool"
return smartcontract.BoolType, stackitem.BooleanT, over return smartcontract.BoolType, stackitem.BooleanT, over, nil
case info&types.IsString != 0: case info&types.IsString != 0:
over.TypeName = "string" over.TypeName = "string"
return smartcontract.StringType, stackitem.ByteArrayT, over return smartcontract.StringType, stackitem.ByteArrayT, over, nil
default: default:
over.TypeName = "interface{}" over.TypeName = "interface{}"
return smartcontract.AnyType, stackitem.AnyT, over return smartcontract.AnyType, stackitem.AnyT, over, nil
} }
case *types.Map: case *types.Map:
_, _, over := c.scAndVMTypeFromType(t.Elem()) et := &binding.ExtendedType{
Base: smartcontract.MapType,
}
et.Key, _, _, _ = c.scAndVMTypeFromType(t.Key(), exts)
vt, _, over, vet := c.scAndVMTypeFromType(t.Elem(), exts)
et.Value = vet
if et.Value == nil {
et.Value = &binding.ExtendedType{Base: vt}
}
over.TypeName = "map[" + t.Key().String() + "]" + over.TypeName over.TypeName = "map[" + t.Key().String() + "]" + over.TypeName
return smartcontract.MapType, stackitem.MapT, over return smartcontract.MapType, stackitem.MapT, over, et
case *types.Struct: case *types.Struct:
if isNamed { if isNamed {
over.Package = named.Obj().Pkg().Path() over.Package = named.Obj().Pkg().Path()
over.TypeName = named.Obj().Pkg().Name() + "." + named.Obj().Name() over.TypeName = named.Obj().Pkg().Name() + "." + named.Obj().Name()
_ = c.genStructExtended(t, over.TypeName, exts)
} else {
name := "unnamed"
if exts != nil {
for exts[name].Name == name {
name = name + "X"
}
_ = c.genStructExtended(t, name, exts)
}
} }
return smartcontract.ArrayType, stackitem.StructT, over return smartcontract.ArrayType, stackitem.StructT, over,
&binding.ExtendedType{ // Value-less, refer to exts.
Base: smartcontract.ArrayType,
Name: over.TypeName,
}
case *types.Slice: case *types.Slice:
if isByte(t.Elem()) { if isByte(t.Elem()) {
over.TypeName = "[]byte" over.TypeName = "[]byte"
return smartcontract.ByteArrayType, stackitem.ByteArrayT, over return smartcontract.ByteArrayType, stackitem.ByteArrayT, over, nil
}
et := &binding.ExtendedType{
Base: smartcontract.ArrayType,
}
vt, _, over, vet := c.scAndVMTypeFromType(t.Elem(), exts)
et.Value = vet
if et.Value == nil {
et.Value = &binding.ExtendedType{
Base: vt,
}
} }
_, _, over := c.scAndVMTypeFromType(t.Elem())
if over.TypeName != "" { if over.TypeName != "" {
over.TypeName = "[]" + over.TypeName over.TypeName = "[]" + over.TypeName
} }
return smartcontract.ArrayType, stackitem.ArrayT, over return smartcontract.ArrayType, stackitem.ArrayT, over, et
default: default:
over.TypeName = "interface{}" over.TypeName = "interface{}"
return smartcontract.AnyType, stackitem.AnyT, over return smartcontract.AnyType, stackitem.AnyT, over, nil
} }
} }
func (c *codegen) genStructExtended(t *types.Struct, name string, exts map[string]binding.ExtendedType) *binding.ExtendedType {
var et *binding.ExtendedType
if exts != nil {
if exts[name].Name != name {
et = &binding.ExtendedType{
Base: smartcontract.ArrayType,
Name: name,
Fields: make([]binding.FieldExtendedType, t.NumFields()),
}
exts[name] = *et // Prefill to solve recursive structures.
for i := range et.Fields {
field := t.Field(i)
ft, _, _, fet := c.scAndVMTypeFromType(field.Type(), exts)
if fet == nil {
et.Fields[i].ExtendedType.Base = ft
} else {
et.Fields[i].ExtendedType = *fet
}
et.Fields[i].Field = field.Name()
}
exts[name] = *et // Set real structure data.
} else {
et = new(binding.ExtendedType)
*et = exts[name]
}
}
return et
}
// MarshalJSON implements the json.Marshaler interface. // MarshalJSON implements the json.Marshaler interface.
func (d *DebugRange) MarshalJSON() ([]byte, error) { func (d *DebugRange) MarshalJSON() ([]byte, error) {
return []byte(`"` + strconv.FormatUint(uint64(d.Start), 10) + `-` + return []byte(`"` + strconv.FormatUint(uint64(d.Start), 10) + `-` +

View file

@ -175,7 +175,7 @@ func (c *codegen) processNotify(f *funcScope, args []ast.Expr, hasEllipsis bool)
params := make([]string, 0, len(args[1:])) params := make([]string, 0, len(args[1:]))
vParams := make([]*stackitem.Type, 0, len(args[1:])) vParams := make([]*stackitem.Type, 0, len(args[1:]))
for _, p := range args[1:] { for _, p := range args[1:] {
st, vt, _ := c.scAndVMTypeFromExpr(p) st, vt, _, _ := c.scAndVMTypeFromExpr(p, nil)
params = append(params, st.String()) params = append(params, st.String())
vParams = append(vParams, &vt) vParams = append(vParams, &vt)
} }

View file

@ -48,12 +48,28 @@ const Hash = "{{ .Hash }}"
type ( type (
// Config contains parameter for the generated binding. // Config contains parameter for the generated binding.
Config struct { Config struct {
Package string `yaml:"package,omitempty"` Package string `yaml:"package,omitempty"`
Manifest *manifest.Manifest `yaml:"-"` Manifest *manifest.Manifest `yaml:"-"`
Hash util.Uint160 `yaml:"hash,omitempty"` Hash util.Uint160 `yaml:"hash,omitempty"`
Overrides map[string]Override `yaml:"overrides,omitempty"` Overrides map[string]Override `yaml:"overrides,omitempty"`
CallFlags map[string]callflag.CallFlag `yaml:"callflags,omitempty"` CallFlags map[string]callflag.CallFlag `yaml:"callflags,omitempty"`
Output io.Writer `yaml:"-"` NamedTypes map[string]ExtendedType `yaml:"namedtypes,omitempty"`
Types map[string]ExtendedType `yaml:"types,omitempty"`
Output io.Writer `yaml:"-"`
}
ExtendedType struct {
Base smartcontract.ParamType `yaml:"base"`
Name string `yaml:"name,omitempty"` // Structure name, omitted for arrays, interfaces and maps.
Interface string `yaml:"interface,omitempty"` // Interface type name, "iterator" only for now.
Key smartcontract.ParamType `yaml:"key,omitempty"` // Key type (only simple types can be used for keys) for maps.
Value *ExtendedType `yaml:"value,omitempty"` // Value type for iterators and arrays.
Fields []FieldExtendedType `yaml:"fields,omitempty"` // Ordered type data for structure fields.
}
FieldExtendedType struct {
Field string `yaml:"field"`
ExtendedType `yaml:",inline"`
} }
ContractTmpl struct { ContractTmpl struct {
@ -84,8 +100,10 @@ var srcTemplate = template.Must(template.New("generate").Parse(srcTmpl))
// NewConfig initializes and returns a new config instance. // NewConfig initializes and returns a new config instance.
func NewConfig() Config { func NewConfig() Config {
return Config{ return Config{
Overrides: make(map[string]Override), Overrides: make(map[string]Override),
CallFlags: make(map[string]callflag.CallFlag), CallFlags: make(map[string]callflag.CallFlag),
NamedTypes: make(map[string]ExtendedType),
Types: make(map[string]ExtendedType),
} }
} }
@ -101,8 +119,8 @@ func Generate(cfg Config) error {
return srcTemplate.Execute(cfg.Output, ctr) return srcTemplate.Execute(cfg.Output, ctr)
} }
func scTypeToGo(name string, typ smartcontract.ParamType, overrides map[string]Override) (string, string) { func scTypeToGo(name string, typ smartcontract.ParamType, cfg *Config) (string, string) {
if over, ok := overrides[name]; ok { if over, ok := cfg.Overrides[name]; ok {
return over.TypeName, over.Package return over.TypeName, over.Package
} }
@ -141,7 +159,7 @@ func scTypeToGo(name string, typ smartcontract.ParamType, overrides map[string]O
// TemplateFromManifest create a contract template using the given configuration // TemplateFromManifest create a contract template using the given configuration
// and type conversion function. It assumes manifest to be present in the // and type conversion function. It assumes manifest to be present in the
// configuration and assumes it to be correct (passing IsValid check). // configuration and assumes it to be correct (passing IsValid check).
func TemplateFromManifest(cfg Config, scTypeConverter func(string, smartcontract.ParamType, map[string]Override) (string, string)) ContractTmpl { func TemplateFromManifest(cfg Config, scTypeConverter func(string, smartcontract.ParamType, *Config) (string, string)) ContractTmpl {
hStr := "" hStr := ""
for _, b := range cfg.Hash.BytesBE() { for _, b := range cfg.Hash.BytesBE() {
hStr += fmt.Sprintf("\\x%02x", b) hStr += fmt.Sprintf("\\x%02x", b)
@ -203,7 +221,7 @@ func TemplateFromManifest(cfg Config, scTypeConverter func(string, smartcontract
var varnames = make(map[string]bool) var varnames = make(map[string]bool)
for i := range m.Parameters { for i := range m.Parameters {
name := m.Parameters[i].Name name := m.Parameters[i].Name
typeStr, pkg := scTypeConverter(m.Name+"."+name, m.Parameters[i].Type, cfg.Overrides) typeStr, pkg := scTypeConverter(m.Name+"."+name, m.Parameters[i].Type, &cfg)
if pkg != "" { if pkg != "" {
imports[pkg] = struct{}{} imports[pkg] = struct{}{}
} }
@ -220,7 +238,7 @@ func TemplateFromManifest(cfg Config, scTypeConverter func(string, smartcontract
}) })
} }
typeStr, pkg := scTypeConverter(m.Name, m.ReturnType, cfg.Overrides) typeStr, pkg := scTypeConverter(m.Name, m.ReturnType, &cfg)
if pkg != "" { if pkg != "" {
imports[pkg] = struct{}{} imports[pkg] = struct{}{}
} }

View file

@ -3,6 +3,7 @@ package rpcbinding
import ( import (
"fmt" "fmt"
"sort" "sort"
"strings"
"text/template" "text/template"
"github.com/nspcc-dev/neo-go/pkg/smartcontract" "github.com/nspcc-dev/neo-go/pkg/smartcontract"
@ -18,14 +19,20 @@ func (c *ContractReader) {{.Name}}({{range $index, $arg := .Arguments -}}
{{- if ne $index 0}}, {{end}} {{- if ne $index 0}}, {{end}}
{{- .Name}} {{.Type}} {{- .Name}} {{.Type}}
{{- end}}) {{if .ReturnType }}({{ .ReturnType }}, error) { {{- end}}) {{if .ReturnType }}({{ .ReturnType }}, error) {
return unwrap.{{.CallFlag}}(c.invoker.Call(Hash, "{{ .NameABI }}"{{/* CallFlag field is used for function name */}} return {{if and (not .ItemTo) (eq .Unwrapper "Item")}}func (item stackitem.Item, err error) ({{ .ReturnType }}, error) {
{{- range $arg := .Arguments -}}, {{.Name}}{{end}})) if err != nil {
return nil, err
}
return {{addIndent (etTypeConverter .ExtendedReturn "item") "\t"}}
} ( {{- end -}} {{if .ItemTo -}} itemTo{{ .ItemTo }}( {{- end -}}
unwrap.{{.Unwrapper}}(c.invoker.Call(Hash, "{{ .NameABI }}"
{{- range $arg := .Arguments -}}, {{.Name}}{{end -}} )) {{- if or .ItemTo (eq .Unwrapper "Item") -}} ) {{- end}}
{{- else -}} (*result.Invoke, error) { {{- else -}} (*result.Invoke, error) {
c.invoker.Call(Hash, "{{ .NameABI }}" c.invoker.Call(Hash, "{{ .NameABI }}"
{{- range $arg := .Arguments -}}, {{.Name}}{{end}}) {{- range $arg := .Arguments -}}, {{.Name}}{{end}})
{{- end}} {{- end}}
} }
{{- if eq .CallFlag "SessionIterator"}} {{- if eq .Unwrapper "SessionIterator"}}
// {{.Name}}Expanded is similar to {{.Name}} (uses the same contract // {{.Name}}Expanded is similar to {{.Name}} (uses the same contract
// method), but can be useful if the server used doesn't support sessions and // method), but can be useful if the server used doesn't support sessions and
@ -101,6 +108,14 @@ import (
// Hash contains contract hash. // Hash contains contract hash.
var Hash = {{ .Hash }} var Hash = {{ .Hash }}
{{range $name, $typ := .NamedTypes}}
// {{toTypeName $name}} is a contract-specific {{$name}} type used by its methods.
type {{toTypeName $name}} struct {
{{- range $m := $typ.Fields}}
{{.Field}} {{etTypeToStr .ExtendedType}}
{{- end}}
}
{{end -}}
{{if .HasReader}}// Invoker is used by ContractReader to call various safe methods. {{if .HasReader}}// Invoker is used by ContractReader to call various safe methods.
type Invoker interface { type Invoker interface {
{{if or .IsNep11D .IsNep11ND}} nep11.Invoker {{if or .IsNep11D .IsNep11ND}} nep11.Invoker
@ -199,15 +214,41 @@ func New(actor Actor) *Contract {
{{end}} {{end}}
{{- range $m := .Methods}} {{- range $m := .Methods}}
{{template "METHOD" $m }} {{template "METHOD" $m }}
{{end}}` {{end}}
{{- range $name, $typ := .NamedTypes}}
// itemTo{{toTypeName $name}} converts stack item into *{{toTypeName $name}}.
func itemTo{{toTypeName $name}}(item stackitem.Item, err error) (*{{toTypeName $name}}, error) {
if err != nil {
return nil, err
}
arr, ok := item.Value().([]stackitem.Item)
if !ok {
return nil, errors.New("not an array")
}
if len(arr) != {{len $typ.Fields}} {
return nil, errors.New("wrong number of structure elements")
}
var srcTemplate = template.Must(template.New("generate").Parse(srcTmpl)) var res = new({{toTypeName $name}})
{{if len .Fields}} var index = -1
{{- range $m := $typ.Fields}}
index++
res.{{.Field}}, err = {{etTypeConverter .ExtendedType "arr[index]"}}
if err != nil {
return nil, fmt.Errorf("field {{.Field}}: %w", err)
}
{{end}}
{{end}}
return res, err
}
{{end}}`
type ( type (
ContractTmpl struct { ContractTmpl struct {
binding.ContractTmpl binding.ContractTmpl
SafeMethods []binding.MethodTmpl SafeMethods []SafeMethodTmpl
NamedTypes map[string]binding.ExtendedType
IsNep11D bool IsNep11D bool
IsNep11ND bool IsNep11ND bool
@ -217,6 +258,13 @@ type (
HasWriter bool HasWriter bool
HasIterator bool HasIterator bool
} }
SafeMethodTmpl struct {
binding.MethodTmpl
Unwrapper string
ItemTo string
ExtendedReturn binding.ExtendedType
}
) )
// NewConfig initializes and returns a new config instance. // NewConfig initializes and returns a new config instance.
@ -268,6 +316,18 @@ func Generate(cfg binding.Config) error {
ctr.ContractTmpl = binding.TemplateFromManifest(cfg, scTypeToGo) ctr.ContractTmpl = binding.TemplateFromManifest(cfg, scTypeToGo)
ctr = scTemplateToRPC(cfg, ctr, imports) ctr = scTemplateToRPC(cfg, ctr, imports)
ctr.NamedTypes = cfg.NamedTypes
var srcTemplate = template.Must(template.New("generate").Funcs(template.FuncMap{
"addIndent": addIndent,
"etTypeConverter": etTypeConverter,
"etTypeToStr": func(et binding.ExtendedType) string {
r, _ := extendedTypeToGo(et, ctr.NamedTypes)
return r
},
"toTypeName": toTypeName,
"cutPointer": cutPointer,
}).Parse(srcTmpl))
return srcTemplate.Execute(cfg.Output, ctr) return srcTemplate.Execute(cfg.Output, ctr)
} }
@ -295,31 +355,8 @@ func dropStdMethods(meths []manifest.Method, std *standard.Standard) []manifest.
return meths return meths
} }
func scTypeToGo(name string, typ smartcontract.ParamType, overrides map[string]binding.Override) (string, string) { func extendedTypeToGo(et binding.ExtendedType, named map[string]binding.ExtendedType) (string, string) {
over, ok := overrides[name] switch et.Base {
if ok {
switch over.TypeName {
case "[]bool":
return "[]bool", ""
case "[]int", "[]uint", "[]int8", "[]uint8", "[]int16",
"[]uint16", "[]int32", "[]uint32", "[]int64", "[]uint64":
return "[]*big.Int", "math/big"
case "[][]byte":
return "[][]byte", ""
case "[]string":
return "[]string", ""
case "[]interop.Hash160":
return "[]util.Uint160", "github.com/nspcc-dev/neo-go/pkg/util"
case "[]interop.Hash256":
return "[]util.Uint256", "github.com/nspcc-dev/neo-go/pkg/util"
case "[]interop.PublicKey":
return "keys.PublicKeys", "github.com/nspcc-dev/neo-go/pkg/crypto/keys"
case "[]interop.Signature":
return "[][]byte", ""
}
}
switch typ {
case smartcontract.AnyType: case smartcontract.AnyType:
return "interface{}", "" return "interface{}", ""
case smartcontract.BoolType: case smartcontract.BoolType:
@ -339,16 +376,148 @@ func scTypeToGo(name string, typ smartcontract.ParamType, overrides map[string]b
case smartcontract.SignatureType: case smartcontract.SignatureType:
return "[]byte", "" return "[]byte", ""
case smartcontract.ArrayType: case smartcontract.ArrayType:
if len(et.Name) > 0 {
return "*" + toTypeName(et.Name), "github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
} else if et.Value != nil {
if et.Value.Base == smartcontract.PublicKeyType { // Special array wrapper.
return "keys.PublicKeys", "github.com/nspcc-dev/neo-go/pkg/crypto/keys"
}
sub, pkg := extendedTypeToGo(*et.Value, named)
return "[]" + sub, pkg
}
return "[]interface{}", "" return "[]interface{}", ""
case smartcontract.MapType: case smartcontract.MapType:
return "*stackitem.Map", "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" kt, _ := extendedTypeToGo(binding.ExtendedType{Base: et.Key}, named)
vt, _ := extendedTypeToGo(*et.Value, named)
return "map[" + kt + "]" + vt, "github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
case smartcontract.InteropInterfaceType: case smartcontract.InteropInterfaceType:
return "interface{}", "" return "interface{}", ""
case smartcontract.VoidType: case smartcontract.VoidType:
return "", "" return "", ""
default:
panic("unreachable")
} }
panic("unreachable")
}
func etTypeConverter(et binding.ExtendedType, v string) string {
switch et.Base {
case smartcontract.AnyType:
return v + ".Value(), nil"
case smartcontract.BoolType:
return v + ".TryBool()"
case smartcontract.IntegerType:
return v + ".TryInteger()"
case smartcontract.ByteArrayType, smartcontract.SignatureType:
return v + ".TryBytes()"
case smartcontract.StringType:
return `func (item stackitem.Item) (string, error) {
b, err := item.TryBytes()
if err != nil {
return "", err
}
if !utf8.Valid(b) {
return "", errors.New("not a UTF-8 string")
}
return string(b), nil
} (` + v + `)`
case smartcontract.Hash160Type:
return `func (item stackitem.Item) (util.Uint160, error) {
b, err := item.TryBytes()
if err != nil {
return util.Uint160{}, err
}
u, err := util.Uint160DecodeBytesBE(b)
if err != nil {
return util.Uint160{}, err
}
return u, nil
} (` + v + `)`
case smartcontract.Hash256Type:
return `func (item stackitem.Item) (util.Uint256, error) {
b, err := item.TryBytes()
if err != nil {
return util.Uint256{}, err
}
u, err := util.Uint256DecodeBytesBE(b)
if err != nil {
return util.Uint256{}, err
}
return u, nil
} (` + v + `)`
case smartcontract.PublicKeyType:
return `func (item stackitem.Item) (*keys.PublicKey, error) {
b, err := item.TryBytes()
if err != nil {
return nil, err
}
k, err := keys.NewPublicKeyFromBytes(b, elliptic.P256())
if err != nil {
return nil, err
}
return k, nil
} (` + v + `)`
case smartcontract.ArrayType:
if len(et.Name) > 0 {
return "itemTo" + toTypeName(et.Name) + "(" + v + ", nil)"
} else if et.Value != nil {
at, _ := extendedTypeToGo(et, nil)
return `func (item stackitem.Item) (` + at + `, error) {
arr, ok := item.Value().([]stackitem.Item)
if !ok {
return nil, errors.New("not an array")
}
res := make(` + at + `, len(arr))
for i := range res {
res[i], err = ` + addIndent(etTypeConverter(*et.Value, "arr[i]"), "\t\t") + `
if err != nil {
return nil, fmt.Errorf("item %d: %w", i, err)
}
}
return res, nil
} (` + v + `)`
}
return etTypeConverter(binding.ExtendedType{
Base: smartcontract.ArrayType,
Value: &binding.ExtendedType{
Base: smartcontract.AnyType,
},
}, v)
case smartcontract.MapType:
at, _ := extendedTypeToGo(et, nil)
return `func (item stackitem.Item) (` + at + `, error) {
m, ok := item.Value().([]stackitem.MapElement)
if !ok {
return nil, fmt.Errorf("%s is not a map", item.Type().String())
}
res := make(` + at + `)
for i := range m {
k, err := ` + addIndent(etTypeConverter(binding.ExtendedType{Base: et.Key}, "m[i].Key"), "\t\t") + `
if err != nil {
return nil, fmt.Errorf("key %d: %w", i, err)
}
v, err := ` + addIndent(etTypeConverter(*et.Value, "m[i].Value"), "\t\t") + `
if err != nil {
return nil, fmt.Errorf("value %d: %w", i, err)
}
res[k] = v
}
return res, nil
} (` + v + `)`
case smartcontract.InteropInterfaceType:
return "item.Value(), nil"
case smartcontract.VoidType:
return ""
}
panic("unreachable")
}
func scTypeToGo(name string, typ smartcontract.ParamType, cfg *binding.Config) (string, string) {
et, ok := cfg.Types[name]
if !ok {
et = binding.ExtendedType{Base: typ}
}
return extendedTypeToGo(et, cfg.NamedTypes)
} }
func scTemplateToRPC(cfg binding.Config, ctr ContractTmpl, imports map[string]struct{}) ContractTmpl { func scTemplateToRPC(cfg binding.Config, ctr ContractTmpl, imports map[string]struct{}) ContractTmpl {
@ -359,7 +528,14 @@ func scTemplateToRPC(cfg binding.Config, ctr ContractTmpl, imports map[string]st
for i := 0; i < len(ctr.Methods); i++ { for i := 0; i < len(ctr.Methods); i++ {
abim := cfg.Manifest.ABI.GetMethod(ctr.Methods[i].NameABI, len(ctr.Methods[i].Arguments)) abim := cfg.Manifest.ABI.GetMethod(ctr.Methods[i].NameABI, len(ctr.Methods[i].Arguments))
if abim.Safe { if abim.Safe {
ctr.SafeMethods = append(ctr.SafeMethods, ctr.Methods[i]) ctr.SafeMethods = append(ctr.SafeMethods, SafeMethodTmpl{MethodTmpl: ctr.Methods[i]})
et, ok := cfg.Types[abim.Name]
if ok {
ctr.SafeMethods[len(ctr.SafeMethods)-1].ExtendedReturn = et
if abim.ReturnType == smartcontract.ArrayType && len(et.Name) > 0 {
ctr.SafeMethods[len(ctr.SafeMethods)-1].ItemTo = cutPointer(ctr.Methods[i].ReturnType)
}
}
ctr.Methods = append(ctr.Methods[:i], ctr.Methods[i+1:]...) ctr.Methods = append(ctr.Methods[:i], ctr.Methods[i+1:]...)
i-- i--
} else { } else {
@ -369,7 +545,13 @@ func scTemplateToRPC(cfg binding.Config, ctr ContractTmpl, imports map[string]st
} }
} }
} }
// We're misusing CallFlag field for function name here. for _, et := range cfg.NamedTypes {
addETImports(et, ctr.NamedTypes, imports)
}
if len(cfg.NamedTypes) > 0 {
imports["errors"] = struct{}{}
}
for i := range ctr.SafeMethods { for i := range ctr.SafeMethods {
switch ctr.SafeMethods[i].ReturnType { switch ctr.SafeMethods[i].ReturnType {
case "interface{}": case "interface{}":
@ -379,47 +561,50 @@ func scTemplateToRPC(cfg binding.Config, ctr ContractTmpl, imports map[string]st
imports["github.com/nspcc-dev/neo-go/pkg/vm/stackitem"] = struct{}{} imports["github.com/nspcc-dev/neo-go/pkg/vm/stackitem"] = struct{}{}
imports["github.com/nspcc-dev/neo-go/pkg/neorpc/result"] = struct{}{} imports["github.com/nspcc-dev/neo-go/pkg/neorpc/result"] = struct{}{}
ctr.SafeMethods[i].ReturnType = "uuid.UUID, result.Iterator" ctr.SafeMethods[i].ReturnType = "uuid.UUID, result.Iterator"
ctr.SafeMethods[i].CallFlag = "SessionIterator" ctr.SafeMethods[i].Unwrapper = "SessionIterator"
ctr.HasIterator = true ctr.HasIterator = true
} else { } else {
imports["github.com/nspcc-dev/neo-go/pkg/vm/stackitem"] = struct{}{} imports["github.com/nspcc-dev/neo-go/pkg/vm/stackitem"] = struct{}{}
ctr.SafeMethods[i].ReturnType = "stackitem.Item" ctr.SafeMethods[i].ReturnType = "stackitem.Item"
ctr.SafeMethods[i].CallFlag = "Item" ctr.SafeMethods[i].Unwrapper = "Item"
} }
case "bool": case "bool":
ctr.SafeMethods[i].CallFlag = "Bool" ctr.SafeMethods[i].Unwrapper = "Bool"
case "*big.Int": case "*big.Int":
ctr.SafeMethods[i].CallFlag = "BigInt" ctr.SafeMethods[i].Unwrapper = "BigInt"
case "string": case "string":
ctr.SafeMethods[i].CallFlag = "UTF8String" ctr.SafeMethods[i].Unwrapper = "UTF8String"
case "util.Uint160": case "util.Uint160":
ctr.SafeMethods[i].CallFlag = "Uint160" ctr.SafeMethods[i].Unwrapper = "Uint160"
case "util.Uint256": case "util.Uint256":
ctr.SafeMethods[i].CallFlag = "Uint256" ctr.SafeMethods[i].Unwrapper = "Uint256"
case "*keys.PublicKey": case "*keys.PublicKey":
ctr.SafeMethods[i].CallFlag = "PublicKey" ctr.SafeMethods[i].Unwrapper = "PublicKey"
case "[]byte": case "[]byte":
ctr.SafeMethods[i].CallFlag = "Bytes" ctr.SafeMethods[i].Unwrapper = "Bytes"
case "[]interface{}": case "[]interface{}":
imports["github.com/nspcc-dev/neo-go/pkg/vm/stackitem"] = struct{}{} imports["github.com/nspcc-dev/neo-go/pkg/vm/stackitem"] = struct{}{}
ctr.SafeMethods[i].ReturnType = "[]stackitem.Item" ctr.SafeMethods[i].ReturnType = "[]stackitem.Item"
ctr.SafeMethods[i].CallFlag = "Array" ctr.SafeMethods[i].Unwrapper = "Array"
case "*stackitem.Map": case "*stackitem.Map":
ctr.SafeMethods[i].CallFlag = "Map" ctr.SafeMethods[i].Unwrapper = "Map"
case "[]bool": case "[]bool":
ctr.SafeMethods[i].CallFlag = "ArrayOfBools" ctr.SafeMethods[i].Unwrapper = "ArrayOfBools"
case "[]*big.Int": case "[]*big.Int":
ctr.SafeMethods[i].CallFlag = "ArrayOfBigInts" ctr.SafeMethods[i].Unwrapper = "ArrayOfBigInts"
case "[][]byte": case "[][]byte":
ctr.SafeMethods[i].CallFlag = "ArrayOfBytes" ctr.SafeMethods[i].Unwrapper = "ArrayOfBytes"
case "[]string": case "[]string":
ctr.SafeMethods[i].CallFlag = "ArrayOfUTF8Strings" ctr.SafeMethods[i].Unwrapper = "ArrayOfUTF8Strings"
case "[]util.Uint160": case "[]util.Uint160":
ctr.SafeMethods[i].CallFlag = "ArrayOfUint160" ctr.SafeMethods[i].Unwrapper = "ArrayOfUint160"
case "[]util.Uint256": case "[]util.Uint256":
ctr.SafeMethods[i].CallFlag = "ArrayOfUint256" ctr.SafeMethods[i].Unwrapper = "ArrayOfUint256"
case "keys.PublicKeys": case "keys.PublicKeys":
ctr.SafeMethods[i].CallFlag = "ArrayOfPublicKeys" ctr.SafeMethods[i].Unwrapper = "ArrayOfPublicKeys"
default:
addETImports(ctr.SafeMethods[i].ExtendedReturn, ctr.NamedTypes, imports)
ctr.SafeMethods[i].Unwrapper = "Item"
} }
} }
@ -446,3 +631,52 @@ func scTemplateToRPC(cfg binding.Config, ctr ContractTmpl, imports map[string]st
sort.Strings(ctr.Imports) sort.Strings(ctr.Imports)
return ctr return ctr
} }
func addETImports(et binding.ExtendedType, named map[string]binding.ExtendedType, imports map[string]struct{}) {
_, pkg := extendedTypeToGo(et, named)
if pkg != "" {
imports[pkg] = struct{}{}
}
// Additional packages used during decoding.
switch et.Base {
case smartcontract.StringType:
imports["unicode/utf8"] = struct{}{}
imports["errors"] = struct{}{}
case smartcontract.PublicKeyType:
imports["crypto/elliptic"] = struct{}{}
case smartcontract.MapType:
imports["fmt"] = struct{}{}
case smartcontract.ArrayType:
imports["errors"] = struct{}{}
imports["fmt"] = struct{}{}
}
if et.Value != nil {
addETImports(*et.Value, named, imports)
}
if et.Base == smartcontract.MapType {
addETImports(binding.ExtendedType{Base: et.Key}, named, imports)
}
for i := range et.Fields {
addETImports(et.Fields[i].ExtendedType, named, imports)
}
}
func cutPointer(s string) string {
if s[0] == '*' {
return s[1:]
}
return s
}
func toTypeName(s string) string {
return strings.Map(func(c rune) rune {
if c == '.' {
return -1
}
return c
}, strings.ToUpper(s[0:1])+s[1:])
}
func addIndent(str string, ind string) string {
return strings.ReplaceAll(str, "\n", "\n"+ind)
}