Merge pull request #2367 from nspcc-dev/rpc/thread-safe
rpc: take care of RPC clients
This commit is contained in:
commit
870fd024c9
13 changed files with 369 additions and 129 deletions
|
@ -682,7 +682,11 @@ func invokeWithArgs(ctx *cli.Context, acc *wallet.Account, wall *wallet.Wallet,
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return sender, cli.NewExitError(fmt.Errorf("failed to create tx: %w", err), 1)
|
return sender, cli.NewExitError(fmt.Errorf("failed to create tx: %w", err), 1)
|
||||||
}
|
}
|
||||||
if err := paramcontext.InitAndSave(c.GetNetwork(), tx, acc, out); err != nil {
|
m, err := c.GetNetwork()
|
||||||
|
if err != nil {
|
||||||
|
return sender, cli.NewExitError(fmt.Errorf("failed to save tx: %w", err), 1)
|
||||||
|
}
|
||||||
|
if err := paramcontext.InitAndSave(m, tx, acc, out); err != nil {
|
||||||
return sender, cli.NewExitError(err, 1)
|
return sender, cli.NewExitError(err, 1)
|
||||||
}
|
}
|
||||||
fmt.Fprintln(ctx.App.Writer, tx.Hash().StringLE())
|
fmt.Fprintln(ctx.App.Writer, tx.Hash().StringLE())
|
||||||
|
|
|
@ -274,7 +274,11 @@ func signAndSendNEP11Transfer(ctx *cli.Context, c *client.Client, acc *wallet.Ac
|
||||||
tx.SystemFee += int64(sysgas)
|
tx.SystemFee += int64(sysgas)
|
||||||
|
|
||||||
if outFile := ctx.String("out"); outFile != "" {
|
if outFile := ctx.String("out"); outFile != "" {
|
||||||
if err := paramcontext.InitAndSave(c.GetNetwork(), tx, acc, outFile); err != nil {
|
m, err := c.GetNetwork()
|
||||||
|
if err != nil {
|
||||||
|
return cli.NewExitError(fmt.Errorf("failed to save tx: %w", err), 1)
|
||||||
|
}
|
||||||
|
if err := paramcontext.InitAndSave(m, tx, acc, outFile); err != nil {
|
||||||
return cli.NewExitError(err, 1)
|
return cli.NewExitError(err, 1)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -647,7 +647,11 @@ func signAndSendNEP17Transfer(ctx *cli.Context, c *client.Client, acc *wallet.Ac
|
||||||
tx.SystemFee += int64(sysgas)
|
tx.SystemFee += int64(sysgas)
|
||||||
|
|
||||||
if outFile := ctx.String("out"); outFile != "" {
|
if outFile := ctx.String("out"); outFile != "" {
|
||||||
if err := paramcontext.InitAndSave(c.GetNetwork(), tx, acc, outFile); err != nil {
|
m, err := c.GetNetwork()
|
||||||
|
if err != nil {
|
||||||
|
return cli.NewExitError(fmt.Errorf("failed to save tx: %w", err), 1)
|
||||||
|
}
|
||||||
|
if err := paramcontext.InitAndSave(m, tx, acc, outFile); err != nil {
|
||||||
return cli.NewExitError(err, 1)
|
return cli.NewExitError(err, 1)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -9,6 +9,7 @@ import (
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/nspcc-dev/neo-go/pkg/config/netmode"
|
"github.com/nspcc-dev/neo-go/pkg/config/netmode"
|
||||||
|
@ -16,6 +17,7 @@ import (
|
||||||
"github.com/nspcc-dev/neo-go/pkg/rpc/request"
|
"github.com/nspcc-dev/neo-go/pkg/rpc/request"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/rpc/response"
|
"github.com/nspcc-dev/neo-go/pkg/rpc/response"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/util"
|
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||||
|
"go.uber.org/atomic"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -26,17 +28,26 @@ const (
|
||||||
)
|
)
|
||||||
|
|
||||||
// Client represents the middleman for executing JSON RPC calls
|
// Client represents the middleman for executing JSON RPC calls
|
||||||
// to remote NEO RPC nodes.
|
// to remote NEO RPC nodes. Client is thread-safe and can be used from
|
||||||
|
// multiple goroutines.
|
||||||
type Client struct {
|
type Client struct {
|
||||||
cli *http.Client
|
cli *http.Client
|
||||||
endpoint *url.URL
|
endpoint *url.URL
|
||||||
network netmode.Magic
|
ctx context.Context
|
||||||
stateRootInHeader bool
|
opts Options
|
||||||
initDone bool
|
requestF func(*request.Raw) (*response.Raw, error)
|
||||||
ctx context.Context
|
|
||||||
opts Options
|
cacheLock sync.RWMutex
|
||||||
requestF func(*request.Raw) (*response.Raw, error)
|
// cache stores RPC node related information client is bound to.
|
||||||
cache cache
|
// cache is mostly filled in during Init(), but can also be updated
|
||||||
|
// during regular Client lifecycle.
|
||||||
|
cache cache
|
||||||
|
|
||||||
|
latestReqID *atomic.Uint64
|
||||||
|
// getNextRequestID returns ID to be used for subsequent request creation.
|
||||||
|
// It is defined on Client so that our testing code can override this method
|
||||||
|
// for the sake of more predictable request IDs generation behaviour.
|
||||||
|
getNextRequestID func() uint64
|
||||||
}
|
}
|
||||||
|
|
||||||
// Options defines options for the RPC client.
|
// Options defines options for the RPC client.
|
||||||
|
@ -56,6 +67,9 @@ type Options struct {
|
||||||
|
|
||||||
// cache stores cache values for the RPC client methods.
|
// cache stores cache values for the RPC client methods.
|
||||||
type cache struct {
|
type cache struct {
|
||||||
|
initDone bool
|
||||||
|
network netmode.Magic
|
||||||
|
stateRootInHeader bool
|
||||||
calculateValidUntilBlock calculateValidUntilBlockCache
|
calculateValidUntilBlock calculateValidUntilBlockCache
|
||||||
nativeHashes map[string]util.Uint160
|
nativeHashes map[string]util.Uint160
|
||||||
}
|
}
|
||||||
|
@ -70,10 +84,19 @@ type calculateValidUntilBlockCache struct {
|
||||||
// New returns a new Client ready to use. You should call Init method to
|
// New returns a new Client ready to use. You should call Init method to
|
||||||
// initialize network magic the client is operating on.
|
// initialize network magic the client is operating on.
|
||||||
func New(ctx context.Context, endpoint string, opts Options) (*Client, error) {
|
func New(ctx context.Context, endpoint string, opts Options) (*Client, error) {
|
||||||
url, err := url.Parse(endpoint)
|
cl := new(Client)
|
||||||
|
err := initClient(ctx, cl, endpoint, opts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
return cl, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func initClient(ctx context.Context, cl *Client, endpoint string, opts Options) error {
|
||||||
|
url, err := url.Parse(endpoint)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
if opts.DialTimeout <= 0 {
|
if opts.DialTimeout <= 0 {
|
||||||
opts.DialTimeout = defaultDialTimeout
|
opts.DialTimeout = defaultDialTimeout
|
||||||
|
@ -97,33 +120,41 @@ func New(ctx context.Context, endpoint string, opts Options) (*Client, error) {
|
||||||
// if opts.Cert != "" && opts.Key != "" {
|
// if opts.Cert != "" && opts.Key != "" {
|
||||||
// }
|
// }
|
||||||
|
|
||||||
cl := &Client{
|
cl.ctx = ctx
|
||||||
ctx: ctx,
|
cl.cli = httpClient
|
||||||
cli: httpClient,
|
cl.endpoint = url
|
||||||
endpoint: url,
|
cl.cache = cache{
|
||||||
cache: cache{
|
nativeHashes: make(map[string]util.Uint160),
|
||||||
nativeHashes: make(map[string]util.Uint160),
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
cl.latestReqID = atomic.NewUint64(0)
|
||||||
|
cl.getNextRequestID = (cl).getRequestID
|
||||||
cl.opts = opts
|
cl.opts = opts
|
||||||
cl.requestF = cl.makeHTTPRequest
|
cl.requestF = cl.makeHTTPRequest
|
||||||
return cl, nil
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) getRequestID() uint64 {
|
||||||
|
return c.latestReqID.Inc()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Init sets magic of the network client connected to, stateRootInHeader option
|
// Init sets magic of the network client connected to, stateRootInHeader option
|
||||||
// and native NEO, GAS and Policy contracts scripthashes. This method should be
|
// and native NEO, GAS and Policy contracts scripthashes. This method should be
|
||||||
// called before any transaction-, header- or block-related requests in order to
|
// called before any header- or block-related requests in order to deserialize
|
||||||
// deserialize responses properly.
|
// responses properly.
|
||||||
func (c *Client) Init() error {
|
func (c *Client) Init() error {
|
||||||
version, err := c.GetVersion()
|
version, err := c.GetVersion()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to get network magic: %w", err)
|
return fmt.Errorf("failed to get network magic: %w", err)
|
||||||
}
|
}
|
||||||
c.network = version.Protocol.Network
|
|
||||||
c.stateRootInHeader = version.Protocol.StateRootInHeader
|
c.cacheLock.Lock()
|
||||||
|
defer c.cacheLock.Unlock()
|
||||||
|
|
||||||
|
c.cache.network = version.Protocol.Network
|
||||||
|
c.cache.stateRootInHeader = version.Protocol.StateRootInHeader
|
||||||
if version.Protocol.MillisecondsPerBlock == 0 {
|
if version.Protocol.MillisecondsPerBlock == 0 {
|
||||||
c.network = version.Magic
|
c.cache.network = version.Magic
|
||||||
c.stateRootInHeader = version.StateRootInHeader
|
c.cache.stateRootInHeader = version.StateRootInHeader
|
||||||
}
|
}
|
||||||
neoContractHash, err := c.GetContractStateByAddressOrName(nativenames.Neo)
|
neoContractHash, err := c.GetContractStateByAddressOrName(nativenames.Neo)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -140,7 +171,7 @@ func (c *Client) Init() error {
|
||||||
return fmt.Errorf("failed to get Policy contract scripthash: %w", err)
|
return fmt.Errorf("failed to get Policy contract scripthash: %w", err)
|
||||||
}
|
}
|
||||||
c.cache.nativeHashes[nativenames.Policy] = policyContractHash.Hash
|
c.cache.nativeHashes[nativenames.Policy] = policyContractHash.Hash
|
||||||
c.initDone = true
|
c.cache.initDone = true
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -149,7 +180,7 @@ func (c *Client) performRequest(method string, p request.RawParams, v interface{
|
||||||
JSONRPC: request.JSONRPCVersion,
|
JSONRPC: request.JSONRPCVersion,
|
||||||
Method: method,
|
Method: method,
|
||||||
RawParams: p.Values,
|
RawParams: p.Values,
|
||||||
ID: 1,
|
ID: c.getNextRequestID(),
|
||||||
}
|
}
|
||||||
|
|
||||||
raw, err := c.requestF(&r)
|
raw, err := c.requestF(&r)
|
||||||
|
|
|
@ -46,9 +46,6 @@ func (c *Client) NEP11TokenInfo(tokenHash util.Uint160) (*wallet.Token, error) {
|
||||||
// given account and sends it to the network returning just a hash of it.
|
// given account and sends it to the network returning just a hash of it.
|
||||||
func (c *Client) TransferNEP11(acc *wallet.Account, to util.Uint160,
|
func (c *Client) TransferNEP11(acc *wallet.Account, to util.Uint160,
|
||||||
tokenHash util.Uint160, tokenID string, data interface{}, gas int64, cosigners []SignerAccount) (util.Uint256, error) {
|
tokenHash util.Uint160, tokenID string, data interface{}, gas int64, cosigners []SignerAccount) (util.Uint256, error) {
|
||||||
if !c.initDone {
|
|
||||||
return util.Uint256{}, errNetworkNotInitialized
|
|
||||||
}
|
|
||||||
tx, err := c.CreateNEP11TransferTx(acc, tokenHash, gas, cosigners, to, tokenID, data)
|
tx, err := c.CreateNEP11TransferTx(acc, tokenHash, gas, cosigners, to, tokenID, data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return util.Uint256{}, err
|
return util.Uint256{}, err
|
||||||
|
@ -144,9 +141,6 @@ func (c *Client) NEP11NDOwnerOf(tokenHash util.Uint160, tokenID []byte) (util.Ui
|
||||||
// sends it to the network returning just a hash of it.
|
// sends it to the network returning just a hash of it.
|
||||||
func (c *Client) TransferNEP11D(acc *wallet.Account, to util.Uint160,
|
func (c *Client) TransferNEP11D(acc *wallet.Account, to util.Uint160,
|
||||||
tokenHash util.Uint160, amount int64, tokenID []byte, data interface{}, gas int64, cosigners []SignerAccount) (util.Uint256, error) {
|
tokenHash util.Uint160, amount int64, tokenID []byte, data interface{}, gas int64, cosigners []SignerAccount) (util.Uint256, error) {
|
||||||
if !c.initDone {
|
|
||||||
return util.Uint256{}, errNetworkNotInitialized
|
|
||||||
}
|
|
||||||
from, err := address.StringToUint160(acc.Address)
|
from, err := address.StringToUint160(acc.Address)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return util.Uint256{}, fmt.Errorf("bad account address: %w", err)
|
return util.Uint256{}, fmt.Errorf("bad account address: %w", err)
|
||||||
|
|
|
@ -142,10 +142,6 @@ func (c *Client) CreateTxFromScript(script []byte, acc *wallet.Account, sysFee,
|
||||||
// impossible (e.g. due to locked cosigner's account) an error is returned.
|
// impossible (e.g. due to locked cosigner's account) an error is returned.
|
||||||
func (c *Client) TransferNEP17(acc *wallet.Account, to util.Uint160, token util.Uint160,
|
func (c *Client) TransferNEP17(acc *wallet.Account, to util.Uint160, token util.Uint160,
|
||||||
amount int64, gas int64, data interface{}, cosigners []SignerAccount) (util.Uint256, error) {
|
amount int64, gas int64, data interface{}, cosigners []SignerAccount) (util.Uint256, error) {
|
||||||
if !c.initDone {
|
|
||||||
return util.Uint256{}, errNetworkNotInitialized
|
|
||||||
}
|
|
||||||
|
|
||||||
tx, err := c.CreateNEP17TransferTx(acc, to, token, amount, gas, data, cosigners)
|
tx, err := c.CreateNEP17TransferTx(acc, to, token, amount, gas, data, cosigners)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return util.Uint256{}, err
|
return util.Uint256{}, err
|
||||||
|
@ -156,10 +152,6 @@ func (c *Client) TransferNEP17(acc *wallet.Account, to util.Uint160, token util.
|
||||||
|
|
||||||
// MultiTransferNEP17 is similar to TransferNEP17, buf allows to have multiple recipients.
|
// MultiTransferNEP17 is similar to TransferNEP17, buf allows to have multiple recipients.
|
||||||
func (c *Client) MultiTransferNEP17(acc *wallet.Account, gas int64, recipients []TransferTarget, cosigners []SignerAccount) (util.Uint256, error) {
|
func (c *Client) MultiTransferNEP17(acc *wallet.Account, gas int64, recipients []TransferTarget, cosigners []SignerAccount) (util.Uint256, error) {
|
||||||
if !c.initDone {
|
|
||||||
return util.Uint256{}, errNetworkNotInitialized
|
|
||||||
}
|
|
||||||
|
|
||||||
tx, err := c.CreateNEP17MultiTransferTx(acc, gas, recipients, cosigners)
|
tx, err := c.CreateNEP17MultiTransferTx(acc, gas, recipients, cosigners)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return util.Uint256{}, err
|
return util.Uint256{}, err
|
||||||
|
|
|
@ -10,25 +10,16 @@ import (
|
||||||
|
|
||||||
// GetFeePerByte invokes `getFeePerByte` method on a native Policy contract.
|
// GetFeePerByte invokes `getFeePerByte` method on a native Policy contract.
|
||||||
func (c *Client) GetFeePerByte() (int64, error) {
|
func (c *Client) GetFeePerByte() (int64, error) {
|
||||||
if !c.initDone {
|
|
||||||
return 0, errNetworkNotInitialized
|
|
||||||
}
|
|
||||||
return c.invokeNativePolicyMethod("getFeePerByte")
|
return c.invokeNativePolicyMethod("getFeePerByte")
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetExecFeeFactor invokes `getExecFeeFactor` method on a native Policy contract.
|
// GetExecFeeFactor invokes `getExecFeeFactor` method on a native Policy contract.
|
||||||
func (c *Client) GetExecFeeFactor() (int64, error) {
|
func (c *Client) GetExecFeeFactor() (int64, error) {
|
||||||
if !c.initDone {
|
|
||||||
return 0, errNetworkNotInitialized
|
|
||||||
}
|
|
||||||
return c.invokeNativePolicyMethod("getExecFeeFactor")
|
return c.invokeNativePolicyMethod("getExecFeeFactor")
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetStoragePrice invokes `getStoragePrice` method on a native Policy contract.
|
// GetStoragePrice invokes `getStoragePrice` method on a native Policy contract.
|
||||||
func (c *Client) GetStoragePrice() (int64, error) {
|
func (c *Client) GetStoragePrice() (int64, error) {
|
||||||
if !c.initDone {
|
|
||||||
return 0, errNetworkNotInitialized
|
|
||||||
}
|
|
||||||
return c.invokeNativePolicyMethod("getStoragePrice")
|
return c.invokeNativePolicyMethod("getStoragePrice")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -43,10 +34,11 @@ func (c *Client) GetMaxNotValidBeforeDelta() (int64, error) {
|
||||||
|
|
||||||
// invokeNativePolicy method invokes Get* method on a native Policy contract.
|
// invokeNativePolicy method invokes Get* method on a native Policy contract.
|
||||||
func (c *Client) invokeNativePolicyMethod(operation string) (int64, error) {
|
func (c *Client) invokeNativePolicyMethod(operation string) (int64, error) {
|
||||||
if !c.initDone {
|
policyHash, err := c.GetNativeContractHash(nativenames.Policy)
|
||||||
return 0, errNetworkNotInitialized
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("failed to get native Policy hash: %w", err)
|
||||||
}
|
}
|
||||||
return c.invokeNativeGetMethod(c.cache.nativeHashes[nativenames.Policy], operation)
|
return c.invokeNativeGetMethod(policyHash, operation)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) invokeNativeGetMethod(hash util.Uint160, operation string) (int64, error) {
|
func (c *Client) invokeNativeGetMethod(hash util.Uint160, operation string) (int64, error) {
|
||||||
|
@ -63,10 +55,11 @@ func (c *Client) invokeNativeGetMethod(hash util.Uint160, operation string) (int
|
||||||
|
|
||||||
// IsBlocked invokes `isBlocked` method on native Policy contract.
|
// IsBlocked invokes `isBlocked` method on native Policy contract.
|
||||||
func (c *Client) IsBlocked(hash util.Uint160) (bool, error) {
|
func (c *Client) IsBlocked(hash util.Uint160) (bool, error) {
|
||||||
if !c.initDone {
|
policyHash, err := c.GetNativeContractHash(nativenames.Policy)
|
||||||
return false, errNetworkNotInitialized
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("failed to get native Policy hash: %w", err)
|
||||||
}
|
}
|
||||||
result, err := c.InvokeFunction(c.cache.nativeHashes[nativenames.Policy], "isBlocked", []smartcontract.Parameter{{
|
result, err := c.InvokeFunction(policyHash, "isBlocked", []smartcontract.Parameter{{
|
||||||
Type: smartcontract.Hash160Type,
|
Type: smartcontract.Hash160Type,
|
||||||
Value: hash,
|
Value: hash,
|
||||||
}}, nil)
|
}}, nil)
|
||||||
|
|
|
@ -94,14 +94,15 @@ func (c *Client) getBlock(params request.RawParams) (*block.Block, error) {
|
||||||
err error
|
err error
|
||||||
b *block.Block
|
b *block.Block
|
||||||
)
|
)
|
||||||
if !c.initDone {
|
|
||||||
return nil, errNetworkNotInitialized
|
|
||||||
}
|
|
||||||
if err = c.performRequest("getblock", params, &resp); err != nil {
|
if err = c.performRequest("getblock", params, &resp); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
r := io.NewBinReaderFromBuf(resp)
|
r := io.NewBinReaderFromBuf(resp)
|
||||||
b = block.New(c.StateRootInHeader())
|
sr, err := c.StateRootInHeader()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
b = block.New(sr)
|
||||||
b.DecodeBinary(r)
|
b.DecodeBinary(r)
|
||||||
if r.Err != nil {
|
if r.Err != nil {
|
||||||
return nil, r.Err
|
return nil, r.Err
|
||||||
|
@ -127,9 +128,11 @@ func (c *Client) getBlockVerbose(params request.RawParams) (*result.Block, error
|
||||||
resp = &result.Block{}
|
resp = &result.Block{}
|
||||||
err error
|
err error
|
||||||
)
|
)
|
||||||
if !c.initDone {
|
sr, err := c.StateRootInHeader()
|
||||||
return nil, errNetworkNotInitialized
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
|
resp.Header.StateRootEnabled = sr
|
||||||
if err = c.performRequest("getblock", params, resp); err != nil {
|
if err = c.performRequest("getblock", params, resp); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -157,14 +160,16 @@ func (c *Client) GetBlockHeader(hash util.Uint256) (*block.Header, error) {
|
||||||
resp []byte
|
resp []byte
|
||||||
h *block.Header
|
h *block.Header
|
||||||
)
|
)
|
||||||
if !c.initDone {
|
|
||||||
return nil, errNetworkNotInitialized
|
|
||||||
}
|
|
||||||
if err := c.performRequest("getblockheader", params, &resp); err != nil {
|
if err := c.performRequest("getblockheader", params, &resp); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
sr, err := c.StateRootInHeader()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
r := io.NewBinReaderFromBuf(resp)
|
r := io.NewBinReaderFromBuf(resp)
|
||||||
h = new(block.Header)
|
h = new(block.Header)
|
||||||
|
h.StateRootEnabled = sr
|
||||||
h.DecodeBinary(r)
|
h.DecodeBinary(r)
|
||||||
if r.Err != nil {
|
if r.Err != nil {
|
||||||
return nil, r.Err
|
return nil, r.Err
|
||||||
|
@ -266,6 +271,14 @@ func (c *Client) GetNativeContracts() ([]state.NativeContract, error) {
|
||||||
if err := c.performRequest("getnativecontracts", params, &resp); err != nil {
|
if err := c.performRequest("getnativecontracts", params, &resp); err != nil {
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Update native contract hashes.
|
||||||
|
c.cacheLock.Lock()
|
||||||
|
for _, cs := range resp {
|
||||||
|
c.cache.nativeHashes[cs.Manifest.Name] = cs.Hash
|
||||||
|
}
|
||||||
|
c.cacheLock.Unlock()
|
||||||
|
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -399,17 +412,13 @@ func (c *Client) GetRawMemPool() ([]util.Uint256, error) {
|
||||||
return *resp, nil
|
return *resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetRawTransaction returns a transaction by hash. You should initialize network magic
|
// GetRawTransaction returns a transaction by hash.
|
||||||
// with Init before calling GetRawTransaction.
|
|
||||||
func (c *Client) GetRawTransaction(hash util.Uint256) (*transaction.Transaction, error) {
|
func (c *Client) GetRawTransaction(hash util.Uint256) (*transaction.Transaction, error) {
|
||||||
var (
|
var (
|
||||||
params = request.NewRawParams(hash.StringLE())
|
params = request.NewRawParams(hash.StringLE())
|
||||||
resp []byte
|
resp []byte
|
||||||
err error
|
err error
|
||||||
)
|
)
|
||||||
if !c.initDone {
|
|
||||||
return nil, errNetworkNotInitialized
|
|
||||||
}
|
|
||||||
if err = c.performRequest("getrawtransaction", params, &resp); err != nil {
|
if err = c.performRequest("getrawtransaction", params, &resp); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -421,8 +430,7 @@ func (c *Client) GetRawTransaction(hash util.Uint256) (*transaction.Transaction,
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetRawTransactionVerbose returns a transaction wrapper with additional
|
// GetRawTransactionVerbose returns a transaction wrapper with additional
|
||||||
// metadata by transaction's hash. You should initialize network magic
|
// metadata by transaction's hash.
|
||||||
// with Init before calling GetRawTransactionVerbose.
|
|
||||||
// NOTE: to get transaction.ID and transaction.Size, use t.Hash() and io.GetVarSize(t) respectively.
|
// NOTE: to get transaction.ID and transaction.Size, use t.Hash() and io.GetVarSize(t) respectively.
|
||||||
func (c *Client) GetRawTransactionVerbose(hash util.Uint256) (*result.TransactionOutputRaw, error) {
|
func (c *Client) GetRawTransactionVerbose(hash util.Uint256) (*result.TransactionOutputRaw, error) {
|
||||||
var (
|
var (
|
||||||
|
@ -430,9 +438,6 @@ func (c *Client) GetRawTransactionVerbose(hash util.Uint256) (*result.Transactio
|
||||||
resp = &result.TransactionOutputRaw{}
|
resp = &result.TransactionOutputRaw{}
|
||||||
err error
|
err error
|
||||||
)
|
)
|
||||||
if !c.initDone {
|
|
||||||
return nil, errNetworkNotInitialized
|
|
||||||
}
|
|
||||||
if err = c.performRequest("getrawtransaction", params, resp); err != nil {
|
if err = c.performRequest("getrawtransaction", params, resp); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -687,7 +692,11 @@ func (c *Client) SignAndPushTx(tx *transaction.Transaction, acc *wallet.Account,
|
||||||
txHash util.Uint256
|
txHash util.Uint256
|
||||||
err error
|
err error
|
||||||
)
|
)
|
||||||
if err = acc.SignTx(c.GetNetwork(), tx); err != nil {
|
m, err := c.GetNetwork()
|
||||||
|
if err != nil {
|
||||||
|
return txHash, fmt.Errorf("failed to sign tx: %w", err)
|
||||||
|
}
|
||||||
|
if err = acc.SignTx(m, tx); err != nil {
|
||||||
return txHash, fmt.Errorf("failed to sign tx: %w", err)
|
return txHash, fmt.Errorf("failed to sign tx: %w", err)
|
||||||
}
|
}
|
||||||
// try to add witnesses for the rest of the signers
|
// try to add witnesses for the rest of the signers
|
||||||
|
@ -695,7 +704,7 @@ func (c *Client) SignAndPushTx(tx *transaction.Transaction, acc *wallet.Account,
|
||||||
var isOk bool
|
var isOk bool
|
||||||
for _, cosigner := range cosigners {
|
for _, cosigner := range cosigners {
|
||||||
if signer.Account == cosigner.Signer.Account {
|
if signer.Account == cosigner.Signer.Account {
|
||||||
err = cosigner.Account.SignTx(c.GetNetwork(), tx)
|
err = cosigner.Account.SignTx(m, tx)
|
||||||
if err != nil { // then account is non-contract-based and locked, but let's provide more detailed error
|
if err != nil { // then account is non-contract-based and locked, but let's provide more detailed error
|
||||||
if paramNum := len(cosigner.Account.Contract.Parameters); paramNum != 0 && cosigner.Account.Contract.Deployed {
|
if paramNum := len(cosigner.Account.Contract.Parameters); paramNum != 0 && cosigner.Account.Contract.Deployed {
|
||||||
return txHash, fmt.Errorf("failed to add contract-based witness for signer #%d (%s): "+
|
return txHash, fmt.Errorf("failed to add contract-based witness for signer #%d (%s): "+
|
||||||
|
@ -771,9 +780,6 @@ func getSigners(sender *wallet.Account, cosigners []SignerAccount) ([]transactio
|
||||||
// Note: client should be initialized before SignAndPushP2PNotaryRequest call.
|
// Note: client should be initialized before SignAndPushP2PNotaryRequest call.
|
||||||
func (c *Client) SignAndPushP2PNotaryRequest(mainTx *transaction.Transaction, fallbackScript []byte, fallbackSysFee int64, fallbackNetFee int64, fallbackValidFor uint32, acc *wallet.Account) (*payload.P2PNotaryRequest, error) {
|
func (c *Client) SignAndPushP2PNotaryRequest(mainTx *transaction.Transaction, fallbackScript []byte, fallbackSysFee int64, fallbackNetFee int64, fallbackValidFor uint32, acc *wallet.Account) (*payload.P2PNotaryRequest, error) {
|
||||||
var err error
|
var err error
|
||||||
if !c.initDone {
|
|
||||||
return nil, errNetworkNotInitialized
|
|
||||||
}
|
|
||||||
notaryHash, err := c.GetNativeContractHash(nativenames.Notary)
|
notaryHash, err := c.GetNativeContractHash(nativenames.Notary)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get native Notary hash: %w", err)
|
return nil, fmt.Errorf("failed to get native Notary hash: %w", err)
|
||||||
|
@ -835,7 +841,11 @@ func (c *Client) SignAndPushP2PNotaryRequest(mainTx *transaction.Transaction, fa
|
||||||
VerificationScript: []byte{},
|
VerificationScript: []byte{},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
if err = acc.SignTx(c.GetNetwork(), fallbackTx); err != nil {
|
m, err := c.GetNetwork()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to sign fallback tx: %w", err)
|
||||||
|
}
|
||||||
|
if err = acc.SignTx(m, fallbackTx); err != nil {
|
||||||
return nil, fmt.Errorf("failed to sign fallback tx: %w", err)
|
return nil, fmt.Errorf("failed to sign fallback tx: %w", err)
|
||||||
}
|
}
|
||||||
fallbackHash := fallbackTx.Hash()
|
fallbackHash := fallbackTx.Hash()
|
||||||
|
@ -844,7 +854,7 @@ func (c *Client) SignAndPushP2PNotaryRequest(mainTx *transaction.Transaction, fa
|
||||||
FallbackTransaction: fallbackTx,
|
FallbackTransaction: fallbackTx,
|
||||||
}
|
}
|
||||||
req.Witness = transaction.Witness{
|
req.Witness = transaction.Witness{
|
||||||
InvocationScript: append([]byte{byte(opcode.PUSHDATA1), 64}, acc.PrivateKey().SignHashable(uint32(c.GetNetwork()), req)...),
|
InvocationScript: append([]byte{byte(opcode.PUSHDATA1), 64}, acc.PrivateKey().SignHashable(uint32(m), req)...),
|
||||||
VerificationScript: acc.GetVerificationScript(),
|
VerificationScript: acc.GetVerificationScript(),
|
||||||
}
|
}
|
||||||
actualHash, err := c.SubmitP2PNotaryRequest(req)
|
actualHash, err := c.SubmitP2PNotaryRequest(req)
|
||||||
|
@ -922,18 +932,23 @@ func (c *Client) CalculateValidUntilBlock() (uint32, error) {
|
||||||
return result, fmt.Errorf("can't get block count: %w", err)
|
return result, fmt.Errorf("can't get block count: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
c.cacheLock.RLock()
|
||||||
if c.cache.calculateValidUntilBlock.expiresAt > blockCount {
|
if c.cache.calculateValidUntilBlock.expiresAt > blockCount {
|
||||||
validatorsCount = c.cache.calculateValidUntilBlock.validatorsCount
|
validatorsCount = c.cache.calculateValidUntilBlock.validatorsCount
|
||||||
|
c.cacheLock.RUnlock()
|
||||||
} else {
|
} else {
|
||||||
|
c.cacheLock.RUnlock()
|
||||||
validators, err := c.GetNextBlockValidators()
|
validators, err := c.GetNextBlockValidators()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return result, fmt.Errorf("can't get validators: %w", err)
|
return result, fmt.Errorf("can't get validators: %w", err)
|
||||||
}
|
}
|
||||||
validatorsCount = uint32(len(validators))
|
validatorsCount = uint32(len(validators))
|
||||||
|
c.cacheLock.Lock()
|
||||||
c.cache.calculateValidUntilBlock = calculateValidUntilBlockCache{
|
c.cache.calculateValidUntilBlock = calculateValidUntilBlockCache{
|
||||||
validatorsCount: validatorsCount,
|
validatorsCount: validatorsCount,
|
||||||
expiresAt: blockCount + cacheTimeout,
|
expiresAt: blockCount + cacheTimeout,
|
||||||
}
|
}
|
||||||
|
c.cacheLock.Unlock()
|
||||||
}
|
}
|
||||||
return blockCount + validatorsCount + 1, nil
|
return blockCount + validatorsCount + 1, nil
|
||||||
}
|
}
|
||||||
|
@ -990,18 +1005,33 @@ func (c *Client) AddNetworkFee(tx *transaction.Transaction, extraFee int64, accs
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetNetwork returns the network magic of the RPC node client connected to.
|
// GetNetwork returns the network magic of the RPC node client connected to.
|
||||||
func (c *Client) GetNetwork() netmode.Magic {
|
func (c *Client) GetNetwork() (netmode.Magic, error) {
|
||||||
return c.network
|
c.cacheLock.RLock()
|
||||||
|
defer c.cacheLock.RUnlock()
|
||||||
|
|
||||||
|
if !c.cache.initDone {
|
||||||
|
return 0, errNetworkNotInitialized
|
||||||
|
}
|
||||||
|
return c.cache.network, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// StateRootInHeader returns true if state root is contained in block header.
|
// StateRootInHeader returns true if state root is contained in block header.
|
||||||
func (c *Client) StateRootInHeader() bool {
|
// You should initialize Client cache with Init() before calling StateRootInHeader.
|
||||||
return c.stateRootInHeader
|
func (c *Client) StateRootInHeader() (bool, error) {
|
||||||
|
c.cacheLock.RLock()
|
||||||
|
defer c.cacheLock.RUnlock()
|
||||||
|
|
||||||
|
if !c.cache.initDone {
|
||||||
|
return false, errNetworkNotInitialized
|
||||||
|
}
|
||||||
|
return c.cache.stateRootInHeader, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetNativeContractHash returns native contract hash by its name.
|
// GetNativeContractHash returns native contract hash by its name.
|
||||||
func (c *Client) GetNativeContractHash(name string) (util.Uint160, error) {
|
func (c *Client) GetNativeContractHash(name string) (util.Uint160, error) {
|
||||||
|
c.cacheLock.RLock()
|
||||||
hash, ok := c.cache.nativeHashes[name]
|
hash, ok := c.cache.nativeHashes[name]
|
||||||
|
c.cacheLock.RUnlock()
|
||||||
if ok {
|
if ok {
|
||||||
return hash, nil
|
return hash, nil
|
||||||
}
|
}
|
||||||
|
@ -1009,6 +1039,8 @@ func (c *Client) GetNativeContractHash(name string) (util.Uint160, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return util.Uint160{}, err
|
return util.Uint160{}, err
|
||||||
}
|
}
|
||||||
|
c.cacheLock.Lock()
|
||||||
c.cache.nativeHashes[name] = cs.Hash
|
c.cache.nativeHashes[name] = cs.Hash
|
||||||
|
c.cacheLock.Unlock()
|
||||||
return cs.Hash, nil
|
return cs.Hash, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,6 +6,7 @@ import (
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math/big"
|
"math/big"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
@ -1704,6 +1705,7 @@ func TestRPCClients(t *testing.T) {
|
||||||
testRPCClient(t, func(ctx context.Context, endpoint string, opts Options) (*Client, error) {
|
testRPCClient(t, func(ctx context.Context, endpoint string, opts Options) (*Client, error) {
|
||||||
c, err := New(ctx, endpoint, opts)
|
c, err := New(ctx, endpoint, opts)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
c.getNextRequestID = getTestRequestID
|
||||||
require.NoError(t, c.Init())
|
require.NoError(t, c.Init())
|
||||||
return c, nil
|
return c, nil
|
||||||
})
|
})
|
||||||
|
@ -1712,6 +1714,7 @@ func TestRPCClients(t *testing.T) {
|
||||||
testRPCClient(t, func(ctx context.Context, endpoint string, opts Options) (*Client, error) {
|
testRPCClient(t, func(ctx context.Context, endpoint string, opts Options) (*Client, error) {
|
||||||
wsc, err := NewWS(ctx, httpURLtoWS(endpoint), opts)
|
wsc, err := NewWS(ctx, httpURLtoWS(endpoint), opts)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
wsc.getNextRequestID = getTestRequestID
|
||||||
require.NoError(t, wsc.Init())
|
require.NoError(t, wsc.Init())
|
||||||
return &wsc.Client, nil
|
return &wsc.Client, nil
|
||||||
})
|
})
|
||||||
|
@ -1731,6 +1734,7 @@ func testRPCClient(t *testing.T, newClient func(context.Context, string, Options
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
c.getNextRequestID = getTestRequestID
|
||||||
|
|
||||||
actual, err := testCase.invoke(c)
|
actual, err := testCase.invoke(c)
|
||||||
if testCase.fails {
|
if testCase.fails {
|
||||||
|
@ -1754,14 +1758,14 @@ func testRPCClient(t *testing.T, newClient func(context.Context, string, Options
|
||||||
|
|
||||||
endpoint := srv.URL
|
endpoint := srv.URL
|
||||||
opts := Options{}
|
opts := Options{}
|
||||||
c, err := newClient(context.TODO(), endpoint, opts)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, testCase := range testBatch {
|
for _, testCase := range testBatch {
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
_, err := testCase.invoke(c)
|
c, err := newClient(context.TODO(), endpoint, opts)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
c.getNextRequestID = getTestRequestID
|
||||||
|
_, err = testCase.invoke(c)
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -1877,6 +1881,7 @@ func TestCalculateValidUntilBlock(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
c.getNextRequestID = getTestRequestID
|
||||||
require.NoError(t, c.Init())
|
require.NoError(t, c.Init())
|
||||||
|
|
||||||
validUntilBlock, err := c.CalculateValidUntilBlock()
|
validUntilBlock, err := c.CalculateValidUntilBlock()
|
||||||
|
@ -1912,9 +1917,11 @@ func TestGetNetwork(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
c.getNextRequestID = getTestRequestID
|
||||||
// network was not initialised
|
// network was not initialised
|
||||||
require.Equal(t, netmode.Magic(0), c.GetNetwork())
|
_, err = c.GetNetwork()
|
||||||
require.Equal(t, false, c.initDone)
|
require.True(t, errors.Is(err, errNetworkNotInitialized))
|
||||||
|
require.Equal(t, false, c.cache.initDone)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("good", func(t *testing.T) {
|
t.Run("good", func(t *testing.T) {
|
||||||
|
@ -1922,8 +1929,11 @@ func TestGetNetwork(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
c.getNextRequestID = getTestRequestID
|
||||||
require.NoError(t, c.Init())
|
require.NoError(t, c.Init())
|
||||||
require.Equal(t, netmode.UnitTestNet, c.GetNetwork())
|
m, err := c.GetNetwork()
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, netmode.UnitTestNet, m)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1941,6 +1951,7 @@ func TestUninitedClient(t *testing.T) {
|
||||||
|
|
||||||
c, err := New(context.TODO(), endpoint, opts)
|
c, err := New(context.TODO(), endpoint, opts)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
c.getNextRequestID = getTestRequestID
|
||||||
|
|
||||||
_, err = c.GetBlockByIndex(0)
|
_, err = c.GetBlockByIndex(0)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
|
@ -1966,3 +1977,7 @@ func newTestNEF(script []byte) nef.File {
|
||||||
ne.Checksum = ne.CalculateChecksum()
|
ne.Checksum = ne.CalculateChecksum()
|
||||||
return ne
|
return ne
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getTestRequestID() uint64 {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
|
@ -4,6 +4,8 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
|
"strconv"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
|
@ -20,6 +22,8 @@ import (
|
||||||
// servers. It's supposed to be faster than Client because it has persistent
|
// servers. It's supposed to be faster than Client because it has persistent
|
||||||
// connection to the server and at the same time is exposes some functionality
|
// connection to the server and at the same time is exposes some functionality
|
||||||
// that is only provided via websockets (like event subscription mechanism).
|
// that is only provided via websockets (like event subscription mechanism).
|
||||||
|
// WSClient is thread-safe and can be used from multiple goroutines to perform
|
||||||
|
// RPC requests.
|
||||||
type WSClient struct {
|
type WSClient struct {
|
||||||
Client
|
Client
|
||||||
// Notifications is a channel that is used to send events received from
|
// Notifications is a channel that is used to send events received from
|
||||||
|
@ -30,12 +34,16 @@ type WSClient struct {
|
||||||
// be closed, so make sure to handle this.
|
// be closed, so make sure to handle this.
|
||||||
Notifications chan Notification
|
Notifications chan Notification
|
||||||
|
|
||||||
ws *websocket.Conn
|
ws *websocket.Conn
|
||||||
done chan struct{}
|
done chan struct{}
|
||||||
responses chan *response.Raw
|
requests chan *request.Raw
|
||||||
requests chan *request.Raw
|
shutdown chan struct{}
|
||||||
shutdown chan struct{}
|
|
||||||
subscriptions map[string]bool
|
subscriptionsLock sync.RWMutex
|
||||||
|
subscriptions map[string]bool
|
||||||
|
|
||||||
|
respLock sync.RWMutex
|
||||||
|
respChannels map[uint64]chan *response.Raw
|
||||||
}
|
}
|
||||||
|
|
||||||
// Notification represents server-generated notification for client subscriptions.
|
// Notification represents server-generated notification for client subscriptions.
|
||||||
|
@ -73,29 +81,29 @@ const (
|
||||||
// You should call Init method to initialize network magic the client is
|
// You should call Init method to initialize network magic the client is
|
||||||
// operating on.
|
// operating on.
|
||||||
func NewWS(ctx context.Context, endpoint string, opts Options) (*WSClient, error) {
|
func NewWS(ctx context.Context, endpoint string, opts Options) (*WSClient, error) {
|
||||||
cl, err := New(ctx, endpoint, opts)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
cl.cli = nil
|
|
||||||
|
|
||||||
dialer := websocket.Dialer{HandshakeTimeout: opts.DialTimeout}
|
dialer := websocket.Dialer{HandshakeTimeout: opts.DialTimeout}
|
||||||
ws, _, err := dialer.Dial(endpoint, nil)
|
ws, _, err := dialer.Dial(endpoint, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
wsc := &WSClient{
|
wsc := &WSClient{
|
||||||
Client: *cl,
|
Client: Client{},
|
||||||
Notifications: make(chan Notification),
|
Notifications: make(chan Notification),
|
||||||
|
|
||||||
ws: ws,
|
ws: ws,
|
||||||
shutdown: make(chan struct{}),
|
shutdown: make(chan struct{}),
|
||||||
done: make(chan struct{}),
|
done: make(chan struct{}),
|
||||||
responses: make(chan *response.Raw),
|
respChannels: make(map[uint64]chan *response.Raw),
|
||||||
requests: make(chan *request.Raw),
|
requests: make(chan *request.Raw),
|
||||||
subscriptions: make(map[string]bool),
|
subscriptions: make(map[string]bool),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = initClient(ctx, &wsc.Client, endpoint, opts)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
wsc.Client.cli = nil
|
||||||
|
|
||||||
go wsc.wsReader()
|
go wsc.wsReader()
|
||||||
go wsc.wsWriter()
|
go wsc.wsWriter()
|
||||||
wsc.requestF = wsc.makeWsRequest
|
wsc.requestF = wsc.makeWsRequest
|
||||||
|
@ -141,7 +149,12 @@ readloop:
|
||||||
var val interface{}
|
var val interface{}
|
||||||
switch event {
|
switch event {
|
||||||
case response.BlockEventID:
|
case response.BlockEventID:
|
||||||
val = block.New(c.StateRootInHeader())
|
sr, err := c.StateRootInHeader()
|
||||||
|
if err != nil {
|
||||||
|
// Client is not initialised.
|
||||||
|
break
|
||||||
|
}
|
||||||
|
val = block.New(sr)
|
||||||
case response.TransactionEventID:
|
case response.TransactionEventID:
|
||||||
val = &transaction.Transaction{}
|
val = &transaction.Transaction{}
|
||||||
case response.NotificationEventID:
|
case response.NotificationEventID:
|
||||||
|
@ -170,14 +183,27 @@ readloop:
|
||||||
resp.JSONRPC = rr.JSONRPC
|
resp.JSONRPC = rr.JSONRPC
|
||||||
resp.Error = rr.Error
|
resp.Error = rr.Error
|
||||||
resp.Result = rr.Result
|
resp.Result = rr.Result
|
||||||
c.responses <- resp
|
id, err := strconv.Atoi(string(resp.ID))
|
||||||
|
if err != nil {
|
||||||
|
break // Malformed response (invalid response ID).
|
||||||
|
}
|
||||||
|
ch := c.getResponseChannel(uint64(id))
|
||||||
|
if ch == nil {
|
||||||
|
break // Unknown response (unexpected response ID).
|
||||||
|
}
|
||||||
|
ch <- resp
|
||||||
} else {
|
} else {
|
||||||
// Malformed response, neither valid request, nor valid response.
|
// Malformed response, neither valid request, nor valid response.
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
close(c.done)
|
close(c.done)
|
||||||
close(c.responses)
|
c.respLock.Lock()
|
||||||
|
for _, ch := range c.respChannels {
|
||||||
|
close(ch)
|
||||||
|
}
|
||||||
|
c.respChannels = nil
|
||||||
|
c.respLock.Unlock()
|
||||||
close(c.Notifications)
|
close(c.Notifications)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -212,16 +238,41 @@ func (c *WSClient) wsWriter() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *WSClient) registerRespChannel(id uint64, ch chan *response.Raw) {
|
||||||
|
c.respLock.Lock()
|
||||||
|
defer c.respLock.Unlock()
|
||||||
|
c.respChannels[id] = ch
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *WSClient) unregisterRespChannel(id uint64) {
|
||||||
|
c.respLock.Lock()
|
||||||
|
defer c.respLock.Unlock()
|
||||||
|
if ch, ok := c.respChannels[id]; ok {
|
||||||
|
delete(c.respChannels, id)
|
||||||
|
close(ch)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *WSClient) getResponseChannel(id uint64) chan *response.Raw {
|
||||||
|
c.respLock.RLock()
|
||||||
|
defer c.respLock.RUnlock()
|
||||||
|
return c.respChannels[id]
|
||||||
|
}
|
||||||
|
|
||||||
func (c *WSClient) makeWsRequest(r *request.Raw) (*response.Raw, error) {
|
func (c *WSClient) makeWsRequest(r *request.Raw) (*response.Raw, error) {
|
||||||
|
ch := make(chan *response.Raw)
|
||||||
|
c.registerRespChannel(r.ID, ch)
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-c.done:
|
case <-c.done:
|
||||||
return nil, errors.New("connection lost")
|
return nil, errors.New("connection lost before sending the request")
|
||||||
case c.requests <- r:
|
case c.requests <- r:
|
||||||
}
|
}
|
||||||
select {
|
select {
|
||||||
case <-c.done:
|
case <-c.done:
|
||||||
return nil, errors.New("connection lost")
|
return nil, errors.New("connection lost while waiting for the response")
|
||||||
case resp := <-c.responses:
|
case resp := <-ch:
|
||||||
|
c.unregisterRespChannel(r.ID)
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -232,6 +283,10 @@ func (c *WSClient) performSubscription(params request.RawParams) (string, error)
|
||||||
if err := c.performRequest("subscribe", params, &resp); err != nil {
|
if err := c.performRequest("subscribe", params, &resp); err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
c.subscriptionsLock.Lock()
|
||||||
|
defer c.subscriptionsLock.Unlock()
|
||||||
|
|
||||||
c.subscriptions[resp] = true
|
c.subscriptions[resp] = true
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
@ -239,6 +294,9 @@ func (c *WSClient) performSubscription(params request.RawParams) (string, error)
|
||||||
func (c *WSClient) performUnsubscription(id string) error {
|
func (c *WSClient) performUnsubscription(id string) error {
|
||||||
var resp bool
|
var resp bool
|
||||||
|
|
||||||
|
c.subscriptionsLock.Lock()
|
||||||
|
defer c.subscriptionsLock.Unlock()
|
||||||
|
|
||||||
if !c.subscriptions[id] {
|
if !c.subscriptions[id] {
|
||||||
return errors.New("no subscription with this ID")
|
return errors.New("no subscription with this ID")
|
||||||
}
|
}
|
||||||
|
@ -320,11 +378,18 @@ func (c *WSClient) Unsubscribe(id string) error {
|
||||||
|
|
||||||
// UnsubscribeAll removes all active subscriptions of current client.
|
// UnsubscribeAll removes all active subscriptions of current client.
|
||||||
func (c *WSClient) UnsubscribeAll() error {
|
func (c *WSClient) UnsubscribeAll() error {
|
||||||
|
c.subscriptionsLock.Lock()
|
||||||
|
defer c.subscriptionsLock.Unlock()
|
||||||
|
|
||||||
for id := range c.subscriptions {
|
for id := range c.subscriptions {
|
||||||
err := c.performUnsubscription(id)
|
var resp bool
|
||||||
if err != nil {
|
if err := c.performRequest("unsubscribe", request.NewRawParams(id), &resp); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
if !resp {
|
||||||
|
return errors.New("unsubscribe method returned false result")
|
||||||
|
}
|
||||||
|
delete(c.subscriptions, id)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,7 +6,10 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"sort"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -15,6 +18,7 @@ import (
|
||||||
"github.com/nspcc-dev/neo-go/pkg/rpc/request"
|
"github.com/nspcc-dev/neo-go/pkg/rpc/request"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/util"
|
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
"go.uber.org/atomic"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestWSClientClose(t *testing.T) {
|
func TestWSClientClose(t *testing.T) {
|
||||||
|
@ -45,6 +49,7 @@ func TestWSClientSubscription(t *testing.T) {
|
||||||
srv := initTestServer(t, `{"jsonrpc": "2.0", "id": 1, "result": "55aaff00"}`)
|
srv := initTestServer(t, `{"jsonrpc": "2.0", "id": 1, "result": "55aaff00"}`)
|
||||||
wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{})
|
wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
wsc.getNextRequestID = getTestRequestID
|
||||||
require.NoError(t, wsc.Init())
|
require.NoError(t, wsc.Init())
|
||||||
id, err := f(wsc)
|
id, err := f(wsc)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
@ -58,6 +63,7 @@ func TestWSClientSubscription(t *testing.T) {
|
||||||
srv := initTestServer(t, `{"jsonrpc": "2.0", "id": 1, "error":{"code":-32602,"message":"Invalid Params"}}`)
|
srv := initTestServer(t, `{"jsonrpc": "2.0", "id": 1, "error":{"code":-32602,"message":"Invalid Params"}}`)
|
||||||
wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{})
|
wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
wsc.getNextRequestID = getTestRequestID
|
||||||
require.NoError(t, wsc.Init())
|
require.NoError(t, wsc.Init())
|
||||||
_, err = f(wsc)
|
_, err = f(wsc)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
|
@ -107,6 +113,7 @@ func TestWSClientUnsubscription(t *testing.T) {
|
||||||
srv := initTestServer(t, rc.response)
|
srv := initTestServer(t, rc.response)
|
||||||
wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{})
|
wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
wsc.getNextRequestID = getTestRequestID
|
||||||
require.NoError(t, wsc.Init())
|
require.NoError(t, wsc.Init())
|
||||||
rc.code(t, wsc)
|
rc.code(t, wsc)
|
||||||
})
|
})
|
||||||
|
@ -143,7 +150,9 @@ func TestWSClientEvents(t *testing.T) {
|
||||||
|
|
||||||
wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{})
|
wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
wsc.network = netmode.UnitTestNet
|
wsc.getNextRequestID = getTestRequestID
|
||||||
|
wsc.cache.initDone = true // Our server mock is restricted, so perform initialisation manually.
|
||||||
|
wsc.cache.network = netmode.UnitTestNet
|
||||||
for range events {
|
for range events {
|
||||||
select {
|
select {
|
||||||
case _, ok = <-wsc.Notifications:
|
case _, ok = <-wsc.Notifications:
|
||||||
|
@ -166,6 +175,7 @@ func TestWSExecutionVMStateCheck(t *testing.T) {
|
||||||
srv := initTestServer(t, `{"jsonrpc": "2.0", "id": 1, "result": "55aaff00"}`)
|
srv := initTestServer(t, `{"jsonrpc": "2.0", "id": 1, "result": "55aaff00"}`)
|
||||||
wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{})
|
wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
wsc.getNextRequestID = getTestRequestID
|
||||||
require.NoError(t, wsc.Init())
|
require.NoError(t, wsc.Init())
|
||||||
filter := "NONE"
|
filter := "NONE"
|
||||||
_, err = wsc.SubscribeForTransactionExecutions(&filter)
|
_, err = wsc.SubscribeForTransactionExecutions(&filter)
|
||||||
|
@ -315,7 +325,8 @@ func TestWSFilteredSubscriptions(t *testing.T) {
|
||||||
}))
|
}))
|
||||||
wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{})
|
wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
wsc.network = netmode.UnitTestNet
|
wsc.getNextRequestID = getTestRequestID
|
||||||
|
wsc.cache.network = netmode.UnitTestNet
|
||||||
c.clientCode(t, wsc)
|
c.clientCode(t, wsc)
|
||||||
wsc.Close()
|
wsc.Close()
|
||||||
})
|
})
|
||||||
|
@ -328,6 +339,8 @@ func TestNewWS(t *testing.T) {
|
||||||
t.Run("good", func(t *testing.T) {
|
t.Run("good", func(t *testing.T) {
|
||||||
c, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{})
|
c, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
c.getNextRequestID = getTestRequestID
|
||||||
|
c.cache.network = netmode.UnitTestNet
|
||||||
require.NoError(t, c.Init())
|
require.NoError(t, c.Init())
|
||||||
})
|
})
|
||||||
t.Run("bad URL", func(t *testing.T) {
|
t.Run("bad URL", func(t *testing.T) {
|
||||||
|
@ -335,3 +348,96 @@ func TestNewWS(t *testing.T) {
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestWSConcurrentAccess(t *testing.T) {
|
||||||
|
var ids struct {
|
||||||
|
lock sync.RWMutex
|
||||||
|
m map[int]struct{}
|
||||||
|
}
|
||||||
|
ids.m = make(map[int]struct{})
|
||||||
|
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||||
|
if req.URL.Path == "/ws" && req.Method == "GET" {
|
||||||
|
var upgrader = websocket.Upgrader{}
|
||||||
|
ws, err := upgrader.Upgrade(w, req, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
for {
|
||||||
|
err = ws.SetReadDeadline(time.Now().Add(2 * time.Second))
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, p, err := ws.ReadMessage()
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
r := request.NewIn()
|
||||||
|
err = json.Unmarshal(p, r)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Cannot decode request body: %s", req.Body)
|
||||||
|
}
|
||||||
|
i, err := strconv.Atoi(string(r.RawID))
|
||||||
|
require.NoError(t, err)
|
||||||
|
ids.lock.Lock()
|
||||||
|
ids.m[i] = struct{}{}
|
||||||
|
ids.lock.Unlock()
|
||||||
|
var response string
|
||||||
|
// Different responses to catch possible unmarshalling errors connected with invalid IDs distribution.
|
||||||
|
switch r.Method {
|
||||||
|
case "getblockcount":
|
||||||
|
response = fmt.Sprintf(`{"id":%s,"jsonrpc":"2.0","result":123}`, r.RawID)
|
||||||
|
case "getversion":
|
||||||
|
response = fmt.Sprintf(`{"id":%s,"jsonrpc":"2.0","result":{"network":42,"tcpport":20332,"wsport":20342,"nonce":2153672787,"useragent":"/NEO-GO:0.73.1-pre-273-ge381358/"}}`, r.RawID)
|
||||||
|
case "getblockhash":
|
||||||
|
response = fmt.Sprintf(`{"id":%s,"jsonrpc":"2.0","result":"0x157ca5e5b8cf8f84c9660502a3270b346011612bded1514a6847f877c433a9bb"}`, r.RawID)
|
||||||
|
}
|
||||||
|
err = ws.SetWriteDeadline(time.Now().Add(2 * time.Second))
|
||||||
|
require.NoError(t, err)
|
||||||
|
err = ws.WriteMessage(1, []byte(response))
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ws.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
t.Cleanup(srv.Close)
|
||||||
|
|
||||||
|
wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{})
|
||||||
|
require.NoError(t, err)
|
||||||
|
batchCount := 100
|
||||||
|
completed := atomic.NewInt32(0)
|
||||||
|
for i := 0; i < batchCount; i++ {
|
||||||
|
go func() {
|
||||||
|
_, err := wsc.GetBlockCount()
|
||||||
|
require.NoError(t, err)
|
||||||
|
completed.Inc()
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
_, err := wsc.GetBlockHash(123)
|
||||||
|
require.NoError(t, err)
|
||||||
|
completed.Inc()
|
||||||
|
}()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
_, err := wsc.GetVersion()
|
||||||
|
require.NoError(t, err)
|
||||||
|
completed.Inc()
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
require.Eventually(t, func() bool {
|
||||||
|
return int(completed.Load()) == batchCount*3
|
||||||
|
}, time.Second, 100*time.Millisecond)
|
||||||
|
|
||||||
|
ids.lock.RLock()
|
||||||
|
require.True(t, len(ids.m) > batchCount)
|
||||||
|
idsList := make([]int, 0, len(ids.m))
|
||||||
|
for i := range ids.m {
|
||||||
|
idsList = append(idsList, i)
|
||||||
|
}
|
||||||
|
ids.lock.RUnlock()
|
||||||
|
|
||||||
|
sort.Ints(idsList)
|
||||||
|
require.Equal(t, 1, idsList[0])
|
||||||
|
require.Less(t, idsList[len(idsList)-1],
|
||||||
|
batchCount*3+1) // batchCount*requestsPerBatch+1
|
||||||
|
wsc.Close()
|
||||||
|
}
|
||||||
|
|
|
@ -32,12 +32,12 @@ func NewRawParams(vals ...interface{}) RawParams {
|
||||||
return p
|
return p
|
||||||
}
|
}
|
||||||
|
|
||||||
// Raw represents JSON-RPC request.
|
// Raw represents JSON-RPC request on the Client side.
|
||||||
type Raw struct {
|
type Raw struct {
|
||||||
JSONRPC string `json:"jsonrpc"`
|
JSONRPC string `json:"jsonrpc"`
|
||||||
Method string `json:"method"`
|
Method string `json:"method"`
|
||||||
RawParams []interface{} `json:"params"`
|
RawParams []interface{} `json:"params"`
|
||||||
ID int `json:"id"`
|
ID uint64 `json:"id"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Request contains standard JSON-RPC 2.0 request and batch of
|
// Request contains standard JSON-RPC 2.0 request and batch of
|
||||||
|
|
|
@ -555,9 +555,11 @@ func TestSignAndPushP2PNotaryRequest(t *testing.T) {
|
||||||
|
|
||||||
c, err := client.New(context.Background(), httpSrv.URL, client.Options{})
|
c, err := client.New(context.Background(), httpSrv.URL, client.Options{})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
acc, err := wallet.NewAccount()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
t.Run("client wasn't initialized", func(t *testing.T) {
|
t.Run("client wasn't initialized", func(t *testing.T) {
|
||||||
_, err := c.SignAndPushP2PNotaryRequest(nil, nil, 0, 0, 0, nil)
|
_, err := c.SignAndPushP2PNotaryRequest(transaction.New([]byte{byte(opcode.RET)}, 123), []byte{byte(opcode.RET)}, -1, 0, 100, acc)
|
||||||
require.NotNil(t, err)
|
require.NotNil(t, err)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -567,8 +569,6 @@ func TestSignAndPushP2PNotaryRequest(t *testing.T) {
|
||||||
require.NotNil(t, err)
|
require.NotNil(t, err)
|
||||||
})
|
})
|
||||||
|
|
||||||
acc, err := wallet.NewAccount()
|
|
||||||
require.NoError(t, err)
|
|
||||||
t.Run("bad fallback script", func(t *testing.T) {
|
t.Run("bad fallback script", func(t *testing.T) {
|
||||||
_, err := c.SignAndPushP2PNotaryRequest(nil, []byte{byte(opcode.ASSERT)}, -1, 0, 0, acc)
|
_, err := c.SignAndPushP2PNotaryRequest(nil, []byte{byte(opcode.ASSERT)}, -1, 0, 0, acc)
|
||||||
require.NotNil(t, err)
|
require.NotNil(t, err)
|
||||||
|
@ -649,7 +649,7 @@ func TestCalculateNotaryFee(t *testing.T) {
|
||||||
|
|
||||||
t.Run("client not initialized", func(t *testing.T) {
|
t.Run("client not initialized", func(t *testing.T) {
|
||||||
_, err := c.CalculateNotaryFee(0)
|
_, err := c.CalculateNotaryFee(0)
|
||||||
require.NotNil(t, err)
|
require.NoError(t, err) // Do not require client initialisation for this.
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue