From f155a7f161035f2dad7291e04b9775430c7b07f6 Mon Sep 17 00:00:00 2001 From: Roman Khimov Date: Tue, 9 Aug 2022 15:18:16 +0300 Subject: [PATCH] rpcclient: move result processing code into unwrap package Which will be reused by upper-layer packages. It can be extended with more types in future. --- pkg/rpcclient/helper.go | 174 -------------------- pkg/rpcclient/native.go | 98 ++++++------ pkg/rpcclient/nep.go | 45 +----- pkg/rpcclient/nep11.go | 113 ++----------- pkg/rpcclient/policy.go | 21 +-- pkg/rpcclient/rpc.go | 13 +- pkg/rpcclient/unwrap/unwrap.go | 228 +++++++++++++++++++++++++++ pkg/rpcclient/unwrap/unwrap_test.go | 235 ++++++++++++++++++++++++++++ 8 files changed, 539 insertions(+), 388 deletions(-) create mode 100644 pkg/rpcclient/unwrap/unwrap.go create mode 100644 pkg/rpcclient/unwrap/unwrap_test.go diff --git a/pkg/rpcclient/helper.go b/pkg/rpcclient/helper.go index 10656b289..587816c93 100644 --- a/pkg/rpcclient/helper.go +++ b/pkg/rpcclient/helper.go @@ -1,105 +1,15 @@ package rpcclient import ( - "crypto/elliptic" - "errors" "fmt" "github.com/nspcc-dev/neo-go/pkg/config" "github.com/nspcc-dev/neo-go/pkg/core/transaction" - "github.com/nspcc-dev/neo-go/pkg/crypto/keys" "github.com/nspcc-dev/neo-go/pkg/neorpc/result" - "github.com/nspcc-dev/neo-go/pkg/rpcclient/nns" "github.com/nspcc-dev/neo-go/pkg/smartcontract" "github.com/nspcc-dev/neo-go/pkg/util" - "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" ) -// getInvocationError returns an error in case of bad VM state or an empty stack. -func getInvocationError(result *result.Invoke) error { - if result.State != "HALT" { - return fmt.Errorf("invocation failed: %s", result.FaultException) - } - if len(result.Stack) == 0 { - return errors.New("result stack is empty") - } - return nil -} - -// topBoolFromStack returns the top boolean value from the stack. -func topBoolFromStack(st []stackitem.Item) (bool, error) { - index := len(st) - 1 // top stack element is last in the array - result, ok := st[index].Value().(bool) - if !ok { - return false, fmt.Errorf("invalid stack item type: %s", st[index].Type()) - } - return result, nil -} - -// topIntFromStack returns the top integer value from the stack. -func topIntFromStack(st []stackitem.Item) (int64, error) { - index := len(st) - 1 // top stack element is last in the array - bi, err := st[index].TryInteger() - if err != nil { - return 0, err - } - return bi.Int64(), nil -} - -// topPublicKeysFromStack returns the top array of public keys from the stack. -func topPublicKeysFromStack(st []stackitem.Item) (keys.PublicKeys, error) { - index := len(st) - 1 // top stack element is last in the array - var ( - pks keys.PublicKeys - err error - ) - items, ok := st[index].Value().([]stackitem.Item) - if !ok { - return nil, fmt.Errorf("invalid stack item type: %s", st[index].Type()) - } - pks = make(keys.PublicKeys, len(items)) - for i, item := range items { - val, ok := item.Value().([]byte) - if !ok { - return nil, fmt.Errorf("invalid array element #%d: %s", i, item.Type()) - } - pks[i], err = keys.NewPublicKeyFromBytes(val, elliptic.P256()) - if err != nil { - return nil, err - } - } - return pks, nil -} - -// top string from stack returns the top string from the stack. -func topStringFromStack(st []stackitem.Item) (string, error) { - index := len(st) - 1 // top stack element is last in the array - bs, err := st[index].TryBytes() - if err != nil { - return "", err - } - return string(bs), nil -} - -// topUint160FromStack returns the top util.Uint160 from the stack. -func topUint160FromStack(st []stackitem.Item) (util.Uint160, error) { - index := len(st) - 1 // top stack element is last in the array - bs, err := st[index].TryBytes() - if err != nil { - return util.Uint160{}, err - } - return util.Uint160DecodeBytesBE(bs) -} - -// topMapFromStack returns the top stackitem.Map from the stack. -func topMapFromStack(st []stackitem.Item) (*stackitem.Map, error) { - index := len(st) - 1 // top stack element is last in the array - if t := st[index].Type(); t != stackitem.MapT { - return nil, fmt.Errorf("invalid return stackitem type: %s", t.String()) - } - return st[index].(*stackitem.Map), nil -} - // InvokeAndPackIteratorResults creates a script containing System.Contract.Call // of the specified contract with the specified arguments. It assumes that the // specified operation will return iterator. The script traverses the resulting @@ -132,87 +42,3 @@ func (c *Client) InvokeAndPackIteratorResults(contract util.Uint160, operation s } return c.InvokeScript(bytes, signers) } - -// topIterableFromStack returns the list of elements of `resultItemType` type from the top element -// of the provided stack. The top element is expected to be an Array, otherwise an error is returned. -func topIterableFromStack(st []stackitem.Item, resultItemType interface{}) ([]interface{}, error) { - index := len(st) - 1 // top stack element is the last in the array - if t := st[index].Type(); t != stackitem.ArrayT { - return nil, fmt.Errorf("invalid return stackitem type: %s (Array expected)", t.String()) - } - items, ok := st[index].Value().([]stackitem.Item) - if !ok { - return nil, fmt.Errorf("failed to deserialize iterable from Array stackitem: invalid value type (Array expected)") - } - result := make([]interface{}, len(items)) - for i := range items { - switch resultItemType.(type) { - case []byte: - bytes, err := items[i].TryBytes() - if err != nil { - return nil, fmt.Errorf("failed to deserialize []byte from stackitem #%d: %w", i, err) - } - result[i] = bytes - case string: - bytes, err := items[i].TryBytes() - if err != nil { - return nil, fmt.Errorf("failed to deserialize string from stackitem #%d: %w", i, err) - } - result[i] = string(bytes) - case util.Uint160: - bytes, err := items[i].TryBytes() - if err != nil { - return nil, fmt.Errorf("failed to deserialize uint160 from stackitem #%d: %w", i, err) - } - result[i], err = util.Uint160DecodeBytesBE(bytes) - if err != nil { - return nil, fmt.Errorf("failed to decode uint160 from stackitem #%d: %w", i, err) - } - case nns.RecordState: - rs, ok := items[i].Value().([]stackitem.Item) - if !ok { - return nil, fmt.Errorf("failed to decode RecordState from stackitem #%d: not a struct", i) - } - if len(rs) != 3 { - return nil, fmt.Errorf("failed to decode RecordState from stackitem #%d: wrong number of elements", i) - } - name, err := rs[0].TryBytes() - if err != nil { - return nil, fmt.Errorf("failed to decode RecordState from stackitem #%d: %w", i, err) - } - typ, err := rs[1].TryInteger() - if err != nil { - return nil, fmt.Errorf("failed to decode RecordState from stackitem #%d: %w", i, err) - } - data, err := rs[2].TryBytes() - if err != nil { - return nil, fmt.Errorf("failed to decode RecordState from stackitem #%d: %w", i, err) - } - u64Typ := typ.Uint64() - if !typ.IsUint64() || u64Typ > 255 { - return nil, fmt.Errorf("failed to decode RecordState from stackitem #%d: bad type", i) - } - result[i] = nns.RecordState{ - Name: string(name), - Type: nns.RecordType(u64Typ), - Data: string(data), - } - default: - return nil, errors.New("unsupported iterable type") - } - } - return result, nil -} - -// topIteratorFromStack returns the top Iterator from the stack. -func topIteratorFromStack(st []stackitem.Item) (result.Iterator, error) { - index := len(st) - 1 // top stack element is the last in the array - if t := st[index].Type(); t != stackitem.InteropT { - return result.Iterator{}, fmt.Errorf("expected InteropInterface on stack, got %s", t) - } - iter, ok := st[index].Value().(result.Iterator) - if !ok { - return result.Iterator{}, fmt.Errorf("failed to deserialize iterable from interop stackitem: invalid value type (Iterator expected)") - } - return iter, nil -} diff --git a/pkg/rpcclient/native.go b/pkg/rpcclient/native.go index 8d1741f43..609f353fa 100644 --- a/pkg/rpcclient/native.go +++ b/pkg/rpcclient/native.go @@ -3,6 +3,7 @@ package rpcclient // Various non-policy things from native contracts. import ( + "crypto/elliptic" "errors" "fmt" @@ -13,7 +14,9 @@ import ( "github.com/nspcc-dev/neo-go/pkg/crypto/keys" "github.com/nspcc-dev/neo-go/pkg/neorpc/result" "github.com/nspcc-dev/neo-go/pkg/rpcclient/nns" + "github.com/nspcc-dev/neo-go/pkg/rpcclient/unwrap" "github.com/nspcc-dev/neo-go/pkg/util" + "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" ) // GetOraclePrice invokes `getPrice` method on a native Oracle contract. @@ -54,15 +57,22 @@ func (c *Client) GetDesignatedByRole(role noderoles.Role, index uint32) (keys.Pu if err != nil { return nil, fmt.Errorf("failed to get native RoleManagement hash: %w", err) } - result, err := c.reader.Call(rmHash, "getDesignatedByRole", int64(role), index) + arr, err := unwrap.Array(c.reader.Call(rmHash, "getDesignatedByRole", int64(role), index)) if err != nil { return nil, err } - err = getInvocationError(result) - if err != nil { - return nil, fmt.Errorf("`getDesignatedByRole`: %w", err) + pks := make(keys.PublicKeys, len(arr)) + for i, item := range arr { + val, err := item.TryBytes() + if err != nil { + return nil, fmt.Errorf("invalid array element #%d: %s", i, item.Type()) + } + pks[i], err = keys.NewPublicKeyFromBytes(val, elliptic.P256()) + if err != nil { + return nil, err + } } - return topPublicKeysFromStack(result.Stack) + return pks, nil } // NNSResolve invokes `resolve` method on a NameService contract with the specified hash. @@ -70,28 +80,12 @@ func (c *Client) NNSResolve(nnsHash util.Uint160, name string, typ nns.RecordTyp if typ == nns.CNAME { return "", errors.New("can't resolve CNAME record type") } - result, err := c.reader.Call(nnsHash, "resolve", name, int64(typ)) - if err != nil { - return "", err - } - err = getInvocationError(result) - if err != nil { - return "", fmt.Errorf("`resolve`: %w", err) - } - return topStringFromStack(result.Stack) + return unwrap.UTF8String(c.reader.Call(nnsHash, "resolve", name, int64(typ))) } // NNSIsAvailable invokes `isAvailable` method on a NeoNameService contract with the specified hash. func (c *Client) NNSIsAvailable(nnsHash util.Uint160, name string) (bool, error) { - result, err := c.reader.Call(nnsHash, "isAvailable", name) - if err != nil { - return false, err - } - err = getInvocationError(result) - if err != nil { - return false, fmt.Errorf("`isAvailable`: %w", err) - } - return topBoolFromStack(result.Stack) + return unwrap.Bool(c.reader.Call(nnsHash, "isAvailable", name)) } // NNSGetAllRecords returns iterator over records for a given name from NNS service. @@ -100,17 +94,7 @@ func (c *Client) NNSIsAvailable(nnsHash util.Uint160, name string) (bool, error) // TerminateSession to terminate opened iterator session. See TraverseIterator and // TerminateSession documentation for more details. func (c *Client) NNSGetAllRecords(nnsHash util.Uint160, name string) (uuid.UUID, result.Iterator, error) { - res, err := c.reader.Call(nnsHash, "getAllRecords", name) - if err != nil { - return uuid.UUID{}, result.Iterator{}, err - } - err = getInvocationError(res) - if err != nil { - return uuid.UUID{}, result.Iterator{}, err - } - - iter, err := topIteratorFromStack(res.Stack) - return res.Session, iter, err + return unwrap.SessionIterator(c.reader.Call(nnsHash, "getAllRecords", name)) } // NNSUnpackedGetAllRecords returns a set of records for a given name from NNS service @@ -118,24 +102,42 @@ func (c *Client) NNSGetAllRecords(nnsHash util.Uint160, name string) (uuid.UUID, // that no iterator session is used to retrieve values from iterator. Instead, unpacking // VM script is created and invoked via `invokescript` JSON-RPC call. func (c *Client) NNSUnpackedGetAllRecords(nnsHash util.Uint160, name string) ([]nns.RecordState, error) { - result, err := c.reader.CallAndExpandIterator(nnsHash, "getAllRecords", config.DefaultMaxIteratorResultItems, name) + arr, err := unwrap.Array(c.reader.CallAndExpandIterator(nnsHash, "getAllRecords", config.DefaultMaxIteratorResultItems, name)) if err != nil { return nil, err } - err = getInvocationError(result) - if err != nil { - return nil, err + res := make([]nns.RecordState, len(arr)) + for i := range arr { + rs, ok := arr[i].Value().([]stackitem.Item) + if !ok { + return nil, fmt.Errorf("failed to decode RecordState from stackitem #%d: not a struct", i) + } + if len(rs) != 3 { + return nil, fmt.Errorf("failed to decode RecordState from stackitem #%d: wrong number of elements", i) + } + name, err := rs[0].TryBytes() + if err != nil { + return nil, fmt.Errorf("failed to decode RecordState from stackitem #%d: %w", i, err) + } + typ, err := rs[1].TryInteger() + if err != nil { + return nil, fmt.Errorf("failed to decode RecordState from stackitem #%d: %w", i, err) + } + data, err := rs[2].TryBytes() + if err != nil { + return nil, fmt.Errorf("failed to decode RecordState from stackitem #%d: %w", i, err) + } + u64Typ := typ.Uint64() + if !typ.IsUint64() || u64Typ > 255 { + return nil, fmt.Errorf("failed to decode RecordState from stackitem #%d: bad type", i) + } + res[i] = nns.RecordState{ + Name: string(name), + Type: nns.RecordType(u64Typ), + Data: string(data), + } } - - arr, err := topIterableFromStack(result.Stack, nns.RecordState{}) - if err != nil { - return nil, fmt.Errorf("failed to get token IDs from stack: %w", err) - } - rss := make([]nns.RecordState, len(arr)) - for i := range rss { - rss[i] = arr[i].(nns.RecordState) - } - return rss, nil + return res, nil } // GetNotaryServiceFeePerKey returns a reward per notary request key for the designated diff --git a/pkg/rpcclient/nep.go b/pkg/rpcclient/nep.go index 7e0a5abbf..4b2497315 100644 --- a/pkg/rpcclient/nep.go +++ b/pkg/rpcclient/nep.go @@ -3,50 +3,24 @@ package rpcclient import ( "fmt" + "github.com/nspcc-dev/neo-go/pkg/rpcclient/unwrap" "github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/wallet" ) // nepDecimals invokes `decimals` NEP* method on the specified contract. func (c *Client) nepDecimals(tokenHash util.Uint160) (int64, error) { - result, err := c.reader.Call(tokenHash, "decimals") - if err != nil { - return 0, err - } - err = getInvocationError(result) - if err != nil { - return 0, err - } - - return topIntFromStack(result.Stack) + return unwrap.Int64(c.reader.Call(tokenHash, "decimals")) } // nepSymbol invokes `symbol` NEP* method on the specified contract. func (c *Client) nepSymbol(tokenHash util.Uint160) (string, error) { - result, err := c.reader.Call(tokenHash, "symbol") - if err != nil { - return "", err - } - err = getInvocationError(result) - if err != nil { - return "", err - } - - return topStringFromStack(result.Stack) + return unwrap.PrintableASCIIString(c.reader.Call(tokenHash, "symbol")) } // nepTotalSupply invokes `totalSupply` NEP* method on the specified contract. func (c *Client) nepTotalSupply(tokenHash util.Uint160) (int64, error) { - result, err := c.reader.Call(tokenHash, "totalSupply") - if err != nil { - return 0, err - } - err = getInvocationError(result) - if err != nil { - return 0, err - } - - return topIntFromStack(result.Stack) + return unwrap.Int64(c.reader.Call(tokenHash, "totalSupply")) } // nepBalanceOf invokes `balanceOf` NEP* method on the specified contract. @@ -55,16 +29,7 @@ func (c *Client) nepBalanceOf(tokenHash, acc util.Uint160, tokenID []byte) (int6 if tokenID != nil { params = append(params, tokenID) } - result, err := c.reader.Call(tokenHash, "balanceOf", params...) - if err != nil { - return 0, err - } - err = getInvocationError(result) - if err != nil { - return 0, err - } - - return topIntFromStack(result.Stack) + return unwrap.Int64(c.reader.Call(tokenHash, "balanceOf", params...)) } // nepTokenInfo returns full NEP* token info. diff --git a/pkg/rpcclient/nep11.go b/pkg/rpcclient/nep11.go index 4865ea78f..9ae489b14 100644 --- a/pkg/rpcclient/nep11.go +++ b/pkg/rpcclient/nep11.go @@ -8,6 +8,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/encoding/address" "github.com/nspcc-dev/neo-go/pkg/neorpc/result" + "github.com/nspcc-dev/neo-go/pkg/rpcclient/unwrap" "github.com/nspcc-dev/neo-go/pkg/smartcontract" "github.com/nspcc-dev/neo-go/pkg/smartcontract/manifest" "github.com/nspcc-dev/neo-go/pkg/util" @@ -85,16 +86,7 @@ func (c *Client) CreateNEP11TransferTx(acc *wallet.Account, tokenHash util.Uint1 // traverse iterator values or TerminateSession to terminate opened iterator // session. See TraverseIterator and TerminateSession documentation for more details. func (c *Client) NEP11TokensOf(tokenHash util.Uint160, owner util.Uint160) (uuid.UUID, result.Iterator, error) { - res, err := c.reader.Call(tokenHash, "tokensOf", owner) - if err != nil { - return uuid.UUID{}, result.Iterator{}, err - } - err = getInvocationError(res) - if err != nil { - return uuid.UUID{}, result.Iterator{}, err - } - iter, err := topIteratorFromStack(res.Stack) - return res.Session, iter, err + return unwrap.SessionIterator(c.reader.Call(tokenHash, "tokensOf", owner)) } // NEP11UnpackedTokensOf returns an array of token IDs for the specified owner of the specified NFT token @@ -102,24 +94,7 @@ func (c *Client) NEP11TokensOf(tokenHash util.Uint160, owner util.Uint160) (uuid // is used to retrieve values from iterator. Instead, unpacking VM script is created and invoked via // `invokescript` JSON-RPC call. func (c *Client) NEP11UnpackedTokensOf(tokenHash util.Uint160, owner util.Uint160) ([][]byte, error) { - result, err := c.reader.CallAndExpandIterator(tokenHash, "tokensOf", config.DefaultMaxIteratorResultItems, owner) - if err != nil { - return nil, err - } - err = getInvocationError(result) - if err != nil { - return nil, err - } - - arr, err := topIterableFromStack(result.Stack, []byte{}) - if err != nil { - return nil, fmt.Errorf("failed to get token IDs from stack: %w", err) - } - ids := make([][]byte, len(arr)) - for i := range ids { - ids[i] = arr[i].([]byte) - } - return ids, nil + return unwrap.ArrayOfBytes(c.reader.CallAndExpandIterator(tokenHash, "tokensOf", config.DefaultMaxIteratorResultItems, owner)) } // Non-divisible NFT methods section start. @@ -127,16 +102,7 @@ func (c *Client) NEP11UnpackedTokensOf(tokenHash util.Uint160, owner util.Uint16 // NEP11NDOwnerOf invokes `ownerOf` non-divisible NEP-11 method with the // specified token ID on the specified contract. func (c *Client) NEP11NDOwnerOf(tokenHash util.Uint160, tokenID []byte) (util.Uint160, error) { - result, err := c.reader.Call(tokenHash, "ownerOf", tokenID) - if err != nil { - return util.Uint160{}, err - } - err = getInvocationError(result) - if err != nil { - return util.Uint160{}, err - } - - return topUint160FromStack(result.Stack) + return unwrap.Uint160(c.reader.Call(tokenHash, "ownerOf", tokenID)) } // Non-divisible NFT methods section end. @@ -172,17 +138,7 @@ func (c *Client) NEP11DBalanceOf(tokenHash, owner util.Uint160, tokenID []byte) // method to traverse iterator values or TerminateSession to terminate opened iterator session. See // TraverseIterator and TerminateSession documentation for more details. func (c *Client) NEP11DOwnerOf(tokenHash util.Uint160, tokenID []byte) (uuid.UUID, result.Iterator, error) { - res, err := c.reader.Call(tokenHash, "ownerOf", tokenID) - sessID := res.Session - if err != nil { - return sessID, result.Iterator{}, err - } - err = getInvocationError(res) - if err != nil { - return sessID, result.Iterator{}, err - } - arr, err := topIteratorFromStack(res.Stack) - return sessID, arr, err + return unwrap.SessionIterator(c.reader.Call(tokenHash, "ownerOf", tokenID)) } // NEP11DUnpackedOwnerOf returns list of the specified NEP-11 divisible token owners @@ -190,22 +146,16 @@ func (c *Client) NEP11DOwnerOf(tokenHash util.Uint160, tokenID []byte) (uuid.UUI // iterator session is used to retrieve values from iterator. Instead, unpacking VM // script is created and invoked via `invokescript` JSON-RPC call. func (c *Client) NEP11DUnpackedOwnerOf(tokenHash util.Uint160, tokenID []byte) ([]util.Uint160, error) { - result, err := c.reader.CallAndExpandIterator(tokenHash, "ownerOf", config.DefaultMaxIteratorResultItems, tokenID) + arr, err := unwrap.ArrayOfBytes(c.reader.CallAndExpandIterator(tokenHash, "ownerOf", config.DefaultMaxIteratorResultItems, tokenID)) if err != nil { return nil, err } - err = getInvocationError(result) - if err != nil { - return nil, err - } - - arr, err := topIterableFromStack(result.Stack, util.Uint160{}) - if err != nil { - return nil, fmt.Errorf("failed to get token IDs from stack: %w", err) - } owners := make([]util.Uint160, len(arr)) - for i := range owners { - owners[i] = arr[i].(util.Uint160) + for i := range arr { + owners[i], err = util.Uint160DecodeBytesBE(arr[i]) + if err != nil { + return nil, fmt.Errorf("not a Uint160 at %d: %w", i, err) + } } return owners, nil } @@ -217,16 +167,7 @@ func (c *Client) NEP11DUnpackedOwnerOf(tokenHash util.Uint160, tokenID []byte) ( // NEP11Properties invokes `properties` optional NEP-11 method on the // specified contract. func (c *Client) NEP11Properties(tokenHash util.Uint160, tokenID []byte) (*stackitem.Map, error) { - result, err := c.reader.Call(tokenHash, "properties", tokenID) - if err != nil { - return nil, err - } - err = getInvocationError(result) - if err != nil { - return nil, err - } - - return topMapFromStack(result.Stack) + return unwrap.Map(c.reader.Call(tokenHash, "properties", tokenID)) } // NEP11Tokens returns iterator over the tokens minted by the contract. First return @@ -235,16 +176,7 @@ func (c *Client) NEP11Properties(tokenHash util.Uint160, tokenID []byte) (*stack // TerminateSession to terminate opened iterator session. See TraverseIterator and // TerminateSession documentation for more details. func (c *Client) NEP11Tokens(tokenHash util.Uint160) (uuid.UUID, result.Iterator, error) { - res, err := c.reader.Call(tokenHash, "tokens") - if err != nil { - return uuid.UUID{}, result.Iterator{}, err - } - err = getInvocationError(res) - if err != nil { - return uuid.UUID{}, result.Iterator{}, err - } - iter, err := topIteratorFromStack(res.Stack) - return res.Session, iter, err + return unwrap.SessionIterator(c.reader.Call(tokenHash, "tokens")) } // NEP11UnpackedTokens returns list of the tokens minted by the contract @@ -252,24 +184,7 @@ func (c *Client) NEP11Tokens(tokenHash util.Uint160) (uuid.UUID, result.Iterator // iterator session is used to retrieve values from iterator. Instead, unpacking // VM script is created and invoked via `invokescript` JSON-RPC call. func (c *Client) NEP11UnpackedTokens(tokenHash util.Uint160) ([][]byte, error) { - result, err := c.reader.CallAndExpandIterator(tokenHash, "tokens", config.DefaultMaxIteratorResultItems) - if err != nil { - return nil, err - } - err = getInvocationError(result) - if err != nil { - return nil, err - } - - arr, err := topIterableFromStack(result.Stack, []byte{}) - if err != nil { - return nil, fmt.Errorf("failed to get token IDs from stack: %w", err) - } - tokens := make([][]byte, len(arr)) - for i := range tokens { - tokens[i] = arr[i].([]byte) - } - return tokens, nil + return unwrap.ArrayOfBytes(c.reader.CallAndExpandIterator(tokenHash, "tokens", config.DefaultMaxIteratorResultItems)) } // Optional NFT methods section end. diff --git a/pkg/rpcclient/policy.go b/pkg/rpcclient/policy.go index b2b6372eb..e6a7f1733 100644 --- a/pkg/rpcclient/policy.go +++ b/pkg/rpcclient/policy.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/nspcc-dev/neo-go/pkg/core/native/nativenames" + "github.com/nspcc-dev/neo-go/pkg/rpcclient/unwrap" "github.com/nspcc-dev/neo-go/pkg/util" ) @@ -41,15 +42,7 @@ func (c *Client) invokeNativePolicyMethod(operation string) (int64, error) { } func (c *Client) invokeNativeGetMethod(hash util.Uint160, operation string) (int64, error) { - result, err := c.reader.Call(hash, operation) - if err != nil { - return 0, err - } - err = getInvocationError(result) - if err != nil { - return 0, fmt.Errorf("failed to invoke %s method of native contract %s: %w", operation, hash.StringLE(), err) - } - return topIntFromStack(result.Stack) + return unwrap.Int64(c.reader.Call(hash, operation)) } // IsBlocked invokes `isBlocked` method on native Policy contract. @@ -58,13 +51,5 @@ func (c *Client) IsBlocked(hash util.Uint160) (bool, error) { if err != nil { return false, fmt.Errorf("failed to get native Policy hash: %w", err) } - result, err := c.reader.Call(policyHash, "isBlocked", hash) - if err != nil { - return false, err - } - err = getInvocationError(result) - if err != nil { - return false, fmt.Errorf("failed to check if account is blocked: %w", err) - } - return topBoolFromStack(result.Stack) + return unwrap.Bool(c.reader.Call(policyHash, "isBlocked", hash)) } diff --git a/pkg/rpcclient/rpc.go b/pkg/rpcclient/rpc.go index 84333fb9d..6cdc2c07e 100644 --- a/pkg/rpcclient/rpc.go +++ b/pkg/rpcclient/rpc.go @@ -23,6 +23,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/neorpc" "github.com/nspcc-dev/neo-go/pkg/neorpc/result" "github.com/nspcc-dev/neo-go/pkg/network/payload" + "github.com/nspcc-dev/neo-go/pkg/rpcclient/unwrap" "github.com/nspcc-dev/neo-go/pkg/smartcontract" "github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger" "github.com/nspcc-dev/neo-go/pkg/util" @@ -1064,17 +1065,11 @@ func (c *Client) AddNetworkFee(tx *transaction.Transaction, extraFee int64, accs if err != nil { return fmt.Errorf("failed to invoke verify: %w", err) } - if res.State != "HALT" { - return fmt.Errorf("invalid VM state %s due to an error: %s", res.State, res.FaultException) - } - if l := len(res.Stack); l != 1 { - return fmt.Errorf("result stack length should be equal to 1, got %d", l) - } - r, err := topIntFromStack(res.Stack) + r, err := unwrap.Bool(res, err) if err != nil { - return fmt.Errorf("signer #%d: failed to get `verify` result from stack: %w", i, err) + return fmt.Errorf("signer #%d: %w", i, err) } - if r == 0 { + if !r { return fmt.Errorf("signer #%d: `verify` returned `false`", i) } tx.NetworkFee += res.GasConsumed diff --git a/pkg/rpcclient/unwrap/unwrap.go b/pkg/rpcclient/unwrap/unwrap.go new file mode 100644 index 000000000..25d7961d5 --- /dev/null +++ b/pkg/rpcclient/unwrap/unwrap.go @@ -0,0 +1,228 @@ +/* +Package unwrap provides a set of proxy methods to process invocation results. + +Functions implemented there are intended to be used as wrappers for other +functions that return (*result.Invoke, error) pair (of which there are many). +These functions will check for error, check for VM state, check the number +of results, cast them to appropriate type (if everything is OK) and then +return a result or error. They're mostly useful for other higher-level +contract-specific packages. +*/ +package unwrap + +import ( + "errors" + "fmt" + "math/big" + "unicode/utf8" + + "github.com/google/uuid" + "github.com/nspcc-dev/neo-go/pkg/neorpc/result" + "github.com/nspcc-dev/neo-go/pkg/util" + "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" + "github.com/nspcc-dev/neo-go/pkg/vm/vmstate" +) + +// BigInt expects correct execution (HALT state) with a single stack item +// returned. A big.Int is extracted from this item and returned. +func BigInt(r *result.Invoke, err error) (*big.Int, error) { + itm, err := getSingleItem(r, err) + if err != nil { + return nil, err + } + return itm.TryInteger() +} + +// Bool expects correct execution (HALT state) with a single stack item +// returned. A bool is extracted from this item and returned. +func Bool(r *result.Invoke, err error) (bool, error) { + itm, err := getSingleItem(r, err) + if err != nil { + return false, err + } + return itm.TryBool() +} + +// Int64 expects correct execution (HALT state) with a single stack item +// returned. An int64 is extracted from this item and returned. +func Int64(r *result.Invoke, err error) (int64, error) { + itm, err := getSingleItem(r, err) + if err != nil { + return 0, err + } + i, err := itm.TryInteger() + if err != nil { + return 0, err + } + if !i.IsInt64() { + return 0, errors.New("int64 overflow") + } + return i.Int64(), nil +} + +// LimitedInt64 is similar to Int64 except it allows to set minimum and maximum +// limits to be checked, so if it doesn't return an error the value is more than +// min and less than max. +func LimitedInt64(r *result.Invoke, err error, min int64, max int64) (int64, error) { + i, err := Int64(r, err) + if err != nil { + return 0, err + } + if i < min { + return 0, errors.New("too small value") + } + if i > max { + return 0, errors.New("too big value") + } + return i, nil +} + +// Bytes expects correct execution (HALT state) with a single stack item +// returned. A slice of bytes is extracted from this item and returned. +func Bytes(r *result.Invoke, err error) ([]byte, error) { + itm, err := getSingleItem(r, err) + if err != nil { + return nil, err + } + return itm.TryBytes() +} + +// UTF8String expects correct execution (HALT state) with a single stack item +// returned. A string is extracted from this item and checked for UTF-8 +// correctness, valid strings are then returned. +func UTF8String(r *result.Invoke, err error) (string, error) { + b, err := Bytes(r, err) + if err != nil { + return "", err + } + if !utf8.Valid(b) { + return "", errors.New("not a UTF-8 string") + } + return string(b), nil +} + +// PrintableASCIIString expects correct execution (HALT state) with a single +// stack item returned. A string is extracted from this item and checked to +// only contain ASCII characters in printable range, valid strings are then +// returned. +func PrintableASCIIString(r *result.Invoke, err error) (string, error) { + s, err := UTF8String(r, err) + if err != nil { + return "", err + } + for _, c := range s { + if c < 32 || c >= 127 { + return "", errors.New("not a printable ASCII string") + } + } + return s, nil +} + +// Uint160 expects correct execution (HALT state) with a single stack item +// returned. An util.Uint160 is extracted from this item and returned. +func Uint160(r *result.Invoke, err error) (util.Uint160, error) { + b, err := Bytes(r, err) + if err != nil { + return util.Uint160{}, err + } + return util.Uint160DecodeBytesBE(b) +} + +// Uint256 expects correct execution (HALT state) with a single stack item +// returned. An util.Uint256 is extracted from this item and returned. +func Uint256(r *result.Invoke, err error) (util.Uint256, error) { + b, err := Bytes(r, err) + if err != nil { + return util.Uint256{}, err + } + return util.Uint256DecodeBytesBE(b) +} + +// SessionIterator expects correct execution (HALT state) with a single stack +// item returned. If this item is an iterator it's returned to the caller along +// with the session ID. +func SessionIterator(r *result.Invoke, err error) (uuid.UUID, result.Iterator, error) { + itm, err := getSingleItem(r, err) + if err != nil { + return uuid.UUID{}, result.Iterator{}, err + } + if t := itm.Type(); t != stackitem.InteropT { + return uuid.UUID{}, result.Iterator{}, fmt.Errorf("expected InteropInterface, got %s", t) + } + iter, ok := itm.Value().(result.Iterator) + if !ok { + return uuid.UUID{}, result.Iterator{}, errors.New("the item is InteropInterface, but not an Iterator") + } + return r.Session, iter, nil +} + +// Array expects correct execution (HALT state) with a single array stack item +// returned. This item is returned to the caller. Notice that this function can +// be used for structures as well since they're also represented as slices of +// stack items (the number of them and their types are structure-specific). +func Array(r *result.Invoke, err error) ([]stackitem.Item, error) { + itm, err := getSingleItem(r, err) + if err != nil { + return nil, err + } + arr, ok := itm.Value().([]stackitem.Item) + if !ok { + return nil, errors.New("not an array") + } + return arr, nil +} + +// ArrayOfBytes checks the result for correct state (HALT) and then extracts a +// slice of byte slices from the returned stack item. +func ArrayOfBytes(r *result.Invoke, err error) ([][]byte, error) { + a, err := Array(r, err) + if err != nil { + return nil, err + } + res := make([][]byte, len(a)) + for i := range a { + b, err := a[i].TryBytes() + if err != nil { + return nil, fmt.Errorf("element %d is not a byte string: %w", i, err) + } + res[i] = b + } + return res, nil +} + +// Map expects correct execution (HALT state) with a single stack item +// returned. A stackitem.Map is extracted from this item and returned. +func Map(r *result.Invoke, err error) (*stackitem.Map, error) { + itm, err := getSingleItem(r, err) + if err != nil { + return nil, err + } + if t := itm.Type(); t != stackitem.MapT { + return nil, fmt.Errorf("%s is not a map", t.String()) + } + return itm.(*stackitem.Map), nil +} + +func checkResOK(r *result.Invoke, err error) error { + if err != nil { + return err + } + if r.State != vmstate.Halt.String() { + return fmt.Errorf("invocation failed: %s", r.FaultException) + } + return nil +} + +func getSingleItem(r *result.Invoke, err error) (stackitem.Item, error) { + err = checkResOK(r, err) + if err != nil { + return nil, err + } + if len(r.Stack) == 0 { + return nil, errors.New("result stack is empty") + } + if len(r.Stack) > 1 { + return nil, fmt.Errorf("too many (%d) result items", len(r.Stack)) + } + return r.Stack[0], nil +} diff --git a/pkg/rpcclient/unwrap/unwrap_test.go b/pkg/rpcclient/unwrap/unwrap_test.go new file mode 100644 index 000000000..e62383239 --- /dev/null +++ b/pkg/rpcclient/unwrap/unwrap_test.go @@ -0,0 +1,235 @@ +package unwrap + +import ( + "errors" + "math" + "math/big" + "testing" + + "github.com/google/uuid" + "github.com/nspcc-dev/neo-go/pkg/neorpc/result" + "github.com/nspcc-dev/neo-go/pkg/util" + "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" + "github.com/stretchr/testify/require" +) + +func TestStdErrors(t *testing.T) { + funcs := []func(r *result.Invoke, err error) (interface{}, error){ + func(r *result.Invoke, err error) (interface{}, error) { + return BigInt(r, err) + }, + func(r *result.Invoke, err error) (interface{}, error) { + return Bool(r, err) + }, + func(r *result.Invoke, err error) (interface{}, error) { + return Int64(r, err) + }, + func(r *result.Invoke, err error) (interface{}, error) { + return LimitedInt64(r, err, 0, 1) + }, + func(r *result.Invoke, err error) (interface{}, error) { + return Bytes(r, err) + }, + func(r *result.Invoke, err error) (interface{}, error) { + return UTF8String(r, err) + }, + func(r *result.Invoke, err error) (interface{}, error) { + return PrintableASCIIString(r, err) + }, + func(r *result.Invoke, err error) (interface{}, error) { + return Uint160(r, err) + }, + func(r *result.Invoke, err error) (interface{}, error) { + return Uint256(r, err) + }, + func(r *result.Invoke, err error) (interface{}, error) { + _, _, err = SessionIterator(r, err) + return nil, err + }, + func(r *result.Invoke, err error) (interface{}, error) { + return Array(r, err) + }, + func(r *result.Invoke, err error) (interface{}, error) { + return ArrayOfBytes(r, err) + }, + func(r *result.Invoke, err error) (interface{}, error) { + return Map(r, err) + }, + } + t.Run("error on input", func(t *testing.T) { + for _, f := range funcs { + _, err := f(&result.Invoke{State: "HALT", Stack: []stackitem.Item{stackitem.Make(42)}}, errors.New("some")) + require.Error(t, err) + } + }) + + t.Run("FAULT state", func(t *testing.T) { + for _, f := range funcs { + _, err := f(&result.Invoke{State: "FAULT", Stack: []stackitem.Item{stackitem.Make(42)}}, nil) + require.Error(t, err) + } + }) + t.Run("nothing returned", func(t *testing.T) { + for _, f := range funcs { + _, err := f(&result.Invoke{State: "HALT"}, errors.New("some")) + require.Error(t, err) + } + }) + t.Run("multiple return values", func(t *testing.T) { + for _, f := range funcs { + _, err := f(&result.Invoke{State: "HALT", Stack: []stackitem.Item{stackitem.Make(42), stackitem.Make(42)}}, nil) + require.Error(t, err) + } + }) +} + +func TestBigInt(t *testing.T) { + _, err := BigInt(&result.Invoke{State: "HALT", Stack: []stackitem.Item{stackitem.Make([]stackitem.Item{})}}, nil) + require.Error(t, err) + + i, err := BigInt(&result.Invoke{State: "HALT", Stack: []stackitem.Item{stackitem.Make(42)}}, nil) + require.NoError(t, err) + require.Equal(t, big.NewInt(42), i) +} + +func TestBool(t *testing.T) { + _, err := Bool(&result.Invoke{State: "HALT", Stack: []stackitem.Item{stackitem.Make("0x03c564ed28ba3d50beb1a52dcb751b929e1d747281566bd510363470be186bc0")}}, nil) + require.Error(t, err) + + b, err := Bool(&result.Invoke{State: "HALT", Stack: []stackitem.Item{stackitem.Make(true)}}, nil) + require.NoError(t, err) + require.True(t, b) +} + +func TestInt64(t *testing.T) { + _, err := Int64(&result.Invoke{State: "HALT", Stack: []stackitem.Item{stackitem.Make("0x03c564ed28ba3d50beb1a52dcb751b929e1d747281566bd510363470be186bc0")}}, nil) + require.Error(t, err) + + _, err = Int64(&result.Invoke{State: "HALT", Stack: []stackitem.Item{stackitem.Make(uint64(math.MaxUint64))}}, nil) + require.Error(t, err) + + i, err := Int64(&result.Invoke{State: "HALT", Stack: []stackitem.Item{stackitem.Make(42)}}, nil) + require.NoError(t, err) + require.Equal(t, int64(42), i) +} + +func TestLimitedInt64(t *testing.T) { + _, err := LimitedInt64(&result.Invoke{State: "HALT", Stack: []stackitem.Item{stackitem.Make("0x03c564ed28ba3d50beb1a52dcb751b929e1d747281566bd510363470be186bc0")}}, nil, math.MinInt64, math.MaxInt64) + require.Error(t, err) + + _, err = LimitedInt64(&result.Invoke{State: "HALT", Stack: []stackitem.Item{stackitem.Make(uint64(math.MaxUint64))}}, nil, math.MinInt64, math.MaxInt64) + require.Error(t, err) + + _, err = LimitedInt64(&result.Invoke{State: "HALT", Stack: []stackitem.Item{stackitem.Make(42)}}, nil, 128, 256) + require.Error(t, err) + + _, err = LimitedInt64(&result.Invoke{State: "HALT", Stack: []stackitem.Item{stackitem.Make(42)}}, nil, 0, 40) + require.Error(t, err) + + i, err := LimitedInt64(&result.Invoke{State: "HALT", Stack: []stackitem.Item{stackitem.Make(42)}}, nil, 0, 128) + require.NoError(t, err) + require.Equal(t, int64(42), i) +} + +func TestBytes(t *testing.T) { + _, err := Bytes(&result.Invoke{State: "HALT", Stack: []stackitem.Item{stackitem.Make([]stackitem.Item{})}}, nil) + require.Error(t, err) + + b, err := Bytes(&result.Invoke{State: "HALT", Stack: []stackitem.Item{stackitem.Make([]byte{1, 2, 3})}}, nil) + require.NoError(t, err) + require.Equal(t, []byte{1, 2, 3}, b) +} + +func TestUTF8String(t *testing.T) { + _, err := UTF8String(&result.Invoke{State: "HALT", Stack: []stackitem.Item{stackitem.Make([]stackitem.Item{})}}, nil) + require.Error(t, err) + + _, err = UTF8String(&result.Invoke{State: "HALT", Stack: []stackitem.Item{stackitem.Make("\xff")}}, nil) + require.Error(t, err) + + s, err := UTF8String(&result.Invoke{State: "HALT", Stack: []stackitem.Item{stackitem.Make("value")}}, nil) + require.NoError(t, err) + require.Equal(t, "value", s) +} + +func TestPrintableASCIIString(t *testing.T) { + _, err := PrintableASCIIString(&result.Invoke{State: "HALT", Stack: []stackitem.Item{stackitem.Make([]stackitem.Item{})}}, nil) + require.Error(t, err) + + _, err = PrintableASCIIString(&result.Invoke{State: "HALT", Stack: []stackitem.Item{stackitem.Make("\xff")}}, nil) + require.Error(t, err) + + _, err = PrintableASCIIString(&result.Invoke{State: "HALT", Stack: []stackitem.Item{stackitem.Make("\n\r")}}, nil) + require.Error(t, err) + + s, err := PrintableASCIIString(&result.Invoke{State: "HALT", Stack: []stackitem.Item{stackitem.Make("value")}}, nil) + require.NoError(t, err) + require.Equal(t, "value", s) +} + +func TestUint160(t *testing.T) { + _, err := Uint160(&result.Invoke{State: "HALT", Stack: []stackitem.Item{stackitem.Make(util.Uint256{1, 2, 3}.BytesBE())}}, nil) + require.Error(t, err) + + u, err := Uint160(&result.Invoke{State: "HALT", Stack: []stackitem.Item{stackitem.Make(util.Uint160{1, 2, 3}.BytesBE())}}, nil) + require.NoError(t, err) + require.Equal(t, util.Uint160{1, 2, 3}, u) +} + +func TestUint256(t *testing.T) { + _, err := Uint256(&result.Invoke{State: "HALT", Stack: []stackitem.Item{stackitem.Make(util.Uint160{1, 2, 3}.BytesBE())}}, nil) + require.Error(t, err) + + u, err := Uint256(&result.Invoke{State: "HALT", Stack: []stackitem.Item{stackitem.Make(util.Uint256{1, 2, 3}.BytesBE())}}, nil) + require.NoError(t, err) + require.Equal(t, util.Uint256{1, 2, 3}, u) +} + +func TestSessionIterator(t *testing.T) { + _, _, err := SessionIterator(&result.Invoke{State: "HALT", Stack: []stackitem.Item{stackitem.Make(42)}}, nil) + require.Error(t, err) + + _, _, err = SessionIterator(&result.Invoke{State: "HALT", Stack: []stackitem.Item{stackitem.NewInterop(42)}}, nil) + require.Error(t, err) + + iid := uuid.New() + sid := uuid.New() + iter := result.Iterator{ID: &iid} + rs, ri, err := SessionIterator(&result.Invoke{Session: sid, State: "HALT", Stack: []stackitem.Item{stackitem.NewInterop(iter)}}, nil) + require.NoError(t, err) + require.Equal(t, sid, rs) + require.Equal(t, iter, ri) +} + +func TestArray(t *testing.T) { + _, err := Array(&result.Invoke{State: "HALT", Stack: []stackitem.Item{stackitem.Make(42)}}, nil) + require.Error(t, err) + + a, err := Array(&result.Invoke{State: "HALT", Stack: []stackitem.Item{stackitem.Make([]stackitem.Item{stackitem.Make(42)})}}, nil) + require.NoError(t, err) + require.Equal(t, 1, len(a)) + require.Equal(t, stackitem.Make(42), a[0]) +} + +func TestArrayOfBytes(t *testing.T) { + _, err := ArrayOfBytes(&result.Invoke{State: "HALT", Stack: []stackitem.Item{stackitem.Make(42)}}, nil) + require.Error(t, err) + + _, err = ArrayOfBytes(&result.Invoke{State: "HALT", Stack: []stackitem.Item{stackitem.Make([]stackitem.Item{stackitem.Make([]stackitem.Item{})})}}, nil) + require.Error(t, err) + + a, err := ArrayOfBytes(&result.Invoke{State: "HALT", Stack: []stackitem.Item{stackitem.Make([]stackitem.Item{stackitem.Make([]byte("some"))})}}, nil) + require.NoError(t, err) + require.Equal(t, 1, len(a)) + require.Equal(t, []byte("some"), a[0]) +} + +func TestMap(t *testing.T) { + _, err := Map(&result.Invoke{State: "HALT", Stack: []stackitem.Item{stackitem.Make(42)}}, nil) + require.Error(t, err) + + m, err := Map(&result.Invoke{State: "HALT", Stack: []stackitem.Item{stackitem.NewMapWithValue([]stackitem.MapElement{{Key: stackitem.Make(42), Value: stackitem.Make("string")}})}}, nil) + require.NoError(t, err) + require.Equal(t, 1, m.Len()) + require.Equal(t, 0, m.Index(stackitem.Make(42))) +}