Merge pull request #2367 from nspcc-dev/rpc/thread-safe

rpc: take care of RPC clients
This commit is contained in:
Roman Khimov 2022-02-24 20:11:15 +03:00 committed by GitHub
commit 870fd024c9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 369 additions and 129 deletions

View file

@ -682,7 +682,11 @@ func invokeWithArgs(ctx *cli.Context, acc *wallet.Account, wall *wallet.Wallet,
if err != nil { if err != nil {
return sender, cli.NewExitError(fmt.Errorf("failed to create tx: %w", err), 1) return sender, cli.NewExitError(fmt.Errorf("failed to create tx: %w", err), 1)
} }
if err := paramcontext.InitAndSave(c.GetNetwork(), tx, acc, out); err != nil { m, err := c.GetNetwork()
if err != nil {
return sender, cli.NewExitError(fmt.Errorf("failed to save tx: %w", err), 1)
}
if err := paramcontext.InitAndSave(m, tx, acc, out); err != nil {
return sender, cli.NewExitError(err, 1) return sender, cli.NewExitError(err, 1)
} }
fmt.Fprintln(ctx.App.Writer, tx.Hash().StringLE()) fmt.Fprintln(ctx.App.Writer, tx.Hash().StringLE())

View file

@ -274,7 +274,11 @@ func signAndSendNEP11Transfer(ctx *cli.Context, c *client.Client, acc *wallet.Ac
tx.SystemFee += int64(sysgas) tx.SystemFee += int64(sysgas)
if outFile := ctx.String("out"); outFile != "" { if outFile := ctx.String("out"); outFile != "" {
if err := paramcontext.InitAndSave(c.GetNetwork(), tx, acc, outFile); err != nil { m, err := c.GetNetwork()
if err != nil {
return cli.NewExitError(fmt.Errorf("failed to save tx: %w", err), 1)
}
if err := paramcontext.InitAndSave(m, tx, acc, outFile); err != nil {
return cli.NewExitError(err, 1) return cli.NewExitError(err, 1)
} }
} else { } else {

View file

@ -647,7 +647,11 @@ func signAndSendNEP17Transfer(ctx *cli.Context, c *client.Client, acc *wallet.Ac
tx.SystemFee += int64(sysgas) tx.SystemFee += int64(sysgas)
if outFile := ctx.String("out"); outFile != "" { if outFile := ctx.String("out"); outFile != "" {
if err := paramcontext.InitAndSave(c.GetNetwork(), tx, acc, outFile); err != nil { m, err := c.GetNetwork()
if err != nil {
return cli.NewExitError(fmt.Errorf("failed to save tx: %w", err), 1)
}
if err := paramcontext.InitAndSave(m, tx, acc, outFile); err != nil {
return cli.NewExitError(err, 1) return cli.NewExitError(err, 1)
} }
} else { } else {

View file

@ -9,6 +9,7 @@ import (
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
"sync"
"time" "time"
"github.com/nspcc-dev/neo-go/pkg/config/netmode" "github.com/nspcc-dev/neo-go/pkg/config/netmode"
@ -16,6 +17,7 @@ import (
"github.com/nspcc-dev/neo-go/pkg/rpc/request" "github.com/nspcc-dev/neo-go/pkg/rpc/request"
"github.com/nspcc-dev/neo-go/pkg/rpc/response" "github.com/nspcc-dev/neo-go/pkg/rpc/response"
"github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/util"
"go.uber.org/atomic"
) )
const ( const (
@ -26,17 +28,26 @@ const (
) )
// Client represents the middleman for executing JSON RPC calls // Client represents the middleman for executing JSON RPC calls
// to remote NEO RPC nodes. // to remote NEO RPC nodes. Client is thread-safe and can be used from
// multiple goroutines.
type Client struct { type Client struct {
cli *http.Client cli *http.Client
endpoint *url.URL endpoint *url.URL
network netmode.Magic ctx context.Context
stateRootInHeader bool opts Options
initDone bool requestF func(*request.Raw) (*response.Raw, error)
ctx context.Context
opts Options cacheLock sync.RWMutex
requestF func(*request.Raw) (*response.Raw, error) // cache stores RPC node related information client is bound to.
cache cache // cache is mostly filled in during Init(), but can also be updated
// during regular Client lifecycle.
cache cache
latestReqID *atomic.Uint64
// getNextRequestID returns ID to be used for subsequent request creation.
// It is defined on Client so that our testing code can override this method
// for the sake of more predictable request IDs generation behaviour.
getNextRequestID func() uint64
} }
// Options defines options for the RPC client. // Options defines options for the RPC client.
@ -56,6 +67,9 @@ type Options struct {
// cache stores cache values for the RPC client methods. // cache stores cache values for the RPC client methods.
type cache struct { type cache struct {
initDone bool
network netmode.Magic
stateRootInHeader bool
calculateValidUntilBlock calculateValidUntilBlockCache calculateValidUntilBlock calculateValidUntilBlockCache
nativeHashes map[string]util.Uint160 nativeHashes map[string]util.Uint160
} }
@ -70,10 +84,19 @@ type calculateValidUntilBlockCache struct {
// New returns a new Client ready to use. You should call Init method to // New returns a new Client ready to use. You should call Init method to
// initialize network magic the client is operating on. // initialize network magic the client is operating on.
func New(ctx context.Context, endpoint string, opts Options) (*Client, error) { func New(ctx context.Context, endpoint string, opts Options) (*Client, error) {
url, err := url.Parse(endpoint) cl := new(Client)
err := initClient(ctx, cl, endpoint, opts)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return cl, nil
}
func initClient(ctx context.Context, cl *Client, endpoint string, opts Options) error {
url, err := url.Parse(endpoint)
if err != nil {
return err
}
if opts.DialTimeout <= 0 { if opts.DialTimeout <= 0 {
opts.DialTimeout = defaultDialTimeout opts.DialTimeout = defaultDialTimeout
@ -97,33 +120,41 @@ func New(ctx context.Context, endpoint string, opts Options) (*Client, error) {
// if opts.Cert != "" && opts.Key != "" { // if opts.Cert != "" && opts.Key != "" {
// } // }
cl := &Client{ cl.ctx = ctx
ctx: ctx, cl.cli = httpClient
cli: httpClient, cl.endpoint = url
endpoint: url, cl.cache = cache{
cache: cache{ nativeHashes: make(map[string]util.Uint160),
nativeHashes: make(map[string]util.Uint160),
},
} }
cl.latestReqID = atomic.NewUint64(0)
cl.getNextRequestID = (cl).getRequestID
cl.opts = opts cl.opts = opts
cl.requestF = cl.makeHTTPRequest cl.requestF = cl.makeHTTPRequest
return cl, nil return nil
}
func (c *Client) getRequestID() uint64 {
return c.latestReqID.Inc()
} }
// Init sets magic of the network client connected to, stateRootInHeader option // Init sets magic of the network client connected to, stateRootInHeader option
// and native NEO, GAS and Policy contracts scripthashes. This method should be // and native NEO, GAS and Policy contracts scripthashes. This method should be
// called before any transaction-, header- or block-related requests in order to // called before any header- or block-related requests in order to deserialize
// deserialize responses properly. // responses properly.
func (c *Client) Init() error { func (c *Client) Init() error {
version, err := c.GetVersion() version, err := c.GetVersion()
if err != nil { if err != nil {
return fmt.Errorf("failed to get network magic: %w", err) return fmt.Errorf("failed to get network magic: %w", err)
} }
c.network = version.Protocol.Network
c.stateRootInHeader = version.Protocol.StateRootInHeader c.cacheLock.Lock()
defer c.cacheLock.Unlock()
c.cache.network = version.Protocol.Network
c.cache.stateRootInHeader = version.Protocol.StateRootInHeader
if version.Protocol.MillisecondsPerBlock == 0 { if version.Protocol.MillisecondsPerBlock == 0 {
c.network = version.Magic c.cache.network = version.Magic
c.stateRootInHeader = version.StateRootInHeader c.cache.stateRootInHeader = version.StateRootInHeader
} }
neoContractHash, err := c.GetContractStateByAddressOrName(nativenames.Neo) neoContractHash, err := c.GetContractStateByAddressOrName(nativenames.Neo)
if err != nil { if err != nil {
@ -140,7 +171,7 @@ func (c *Client) Init() error {
return fmt.Errorf("failed to get Policy contract scripthash: %w", err) return fmt.Errorf("failed to get Policy contract scripthash: %w", err)
} }
c.cache.nativeHashes[nativenames.Policy] = policyContractHash.Hash c.cache.nativeHashes[nativenames.Policy] = policyContractHash.Hash
c.initDone = true c.cache.initDone = true
return nil return nil
} }
@ -149,7 +180,7 @@ func (c *Client) performRequest(method string, p request.RawParams, v interface{
JSONRPC: request.JSONRPCVersion, JSONRPC: request.JSONRPCVersion,
Method: method, Method: method,
RawParams: p.Values, RawParams: p.Values,
ID: 1, ID: c.getNextRequestID(),
} }
raw, err := c.requestF(&r) raw, err := c.requestF(&r)

View file

@ -46,9 +46,6 @@ func (c *Client) NEP11TokenInfo(tokenHash util.Uint160) (*wallet.Token, error) {
// given account and sends it to the network returning just a hash of it. // given account and sends it to the network returning just a hash of it.
func (c *Client) TransferNEP11(acc *wallet.Account, to util.Uint160, func (c *Client) TransferNEP11(acc *wallet.Account, to util.Uint160,
tokenHash util.Uint160, tokenID string, data interface{}, gas int64, cosigners []SignerAccount) (util.Uint256, error) { tokenHash util.Uint160, tokenID string, data interface{}, gas int64, cosigners []SignerAccount) (util.Uint256, error) {
if !c.initDone {
return util.Uint256{}, errNetworkNotInitialized
}
tx, err := c.CreateNEP11TransferTx(acc, tokenHash, gas, cosigners, to, tokenID, data) tx, err := c.CreateNEP11TransferTx(acc, tokenHash, gas, cosigners, to, tokenID, data)
if err != nil { if err != nil {
return util.Uint256{}, err return util.Uint256{}, err
@ -144,9 +141,6 @@ func (c *Client) NEP11NDOwnerOf(tokenHash util.Uint160, tokenID []byte) (util.Ui
// sends it to the network returning just a hash of it. // sends it to the network returning just a hash of it.
func (c *Client) TransferNEP11D(acc *wallet.Account, to util.Uint160, func (c *Client) TransferNEP11D(acc *wallet.Account, to util.Uint160,
tokenHash util.Uint160, amount int64, tokenID []byte, data interface{}, gas int64, cosigners []SignerAccount) (util.Uint256, error) { tokenHash util.Uint160, amount int64, tokenID []byte, data interface{}, gas int64, cosigners []SignerAccount) (util.Uint256, error) {
if !c.initDone {
return util.Uint256{}, errNetworkNotInitialized
}
from, err := address.StringToUint160(acc.Address) from, err := address.StringToUint160(acc.Address)
if err != nil { if err != nil {
return util.Uint256{}, fmt.Errorf("bad account address: %w", err) return util.Uint256{}, fmt.Errorf("bad account address: %w", err)

View file

@ -142,10 +142,6 @@ func (c *Client) CreateTxFromScript(script []byte, acc *wallet.Account, sysFee,
// impossible (e.g. due to locked cosigner's account) an error is returned. // impossible (e.g. due to locked cosigner's account) an error is returned.
func (c *Client) TransferNEP17(acc *wallet.Account, to util.Uint160, token util.Uint160, func (c *Client) TransferNEP17(acc *wallet.Account, to util.Uint160, token util.Uint160,
amount int64, gas int64, data interface{}, cosigners []SignerAccount) (util.Uint256, error) { amount int64, gas int64, data interface{}, cosigners []SignerAccount) (util.Uint256, error) {
if !c.initDone {
return util.Uint256{}, errNetworkNotInitialized
}
tx, err := c.CreateNEP17TransferTx(acc, to, token, amount, gas, data, cosigners) tx, err := c.CreateNEP17TransferTx(acc, to, token, amount, gas, data, cosigners)
if err != nil { if err != nil {
return util.Uint256{}, err return util.Uint256{}, err
@ -156,10 +152,6 @@ func (c *Client) TransferNEP17(acc *wallet.Account, to util.Uint160, token util.
// MultiTransferNEP17 is similar to TransferNEP17, buf allows to have multiple recipients. // MultiTransferNEP17 is similar to TransferNEP17, buf allows to have multiple recipients.
func (c *Client) MultiTransferNEP17(acc *wallet.Account, gas int64, recipients []TransferTarget, cosigners []SignerAccount) (util.Uint256, error) { func (c *Client) MultiTransferNEP17(acc *wallet.Account, gas int64, recipients []TransferTarget, cosigners []SignerAccount) (util.Uint256, error) {
if !c.initDone {
return util.Uint256{}, errNetworkNotInitialized
}
tx, err := c.CreateNEP17MultiTransferTx(acc, gas, recipients, cosigners) tx, err := c.CreateNEP17MultiTransferTx(acc, gas, recipients, cosigners)
if err != nil { if err != nil {
return util.Uint256{}, err return util.Uint256{}, err

View file

@ -10,25 +10,16 @@ import (
// GetFeePerByte invokes `getFeePerByte` method on a native Policy contract. // GetFeePerByte invokes `getFeePerByte` method on a native Policy contract.
func (c *Client) GetFeePerByte() (int64, error) { func (c *Client) GetFeePerByte() (int64, error) {
if !c.initDone {
return 0, errNetworkNotInitialized
}
return c.invokeNativePolicyMethod("getFeePerByte") return c.invokeNativePolicyMethod("getFeePerByte")
} }
// GetExecFeeFactor invokes `getExecFeeFactor` method on a native Policy contract. // GetExecFeeFactor invokes `getExecFeeFactor` method on a native Policy contract.
func (c *Client) GetExecFeeFactor() (int64, error) { func (c *Client) GetExecFeeFactor() (int64, error) {
if !c.initDone {
return 0, errNetworkNotInitialized
}
return c.invokeNativePolicyMethod("getExecFeeFactor") return c.invokeNativePolicyMethod("getExecFeeFactor")
} }
// GetStoragePrice invokes `getStoragePrice` method on a native Policy contract. // GetStoragePrice invokes `getStoragePrice` method on a native Policy contract.
func (c *Client) GetStoragePrice() (int64, error) { func (c *Client) GetStoragePrice() (int64, error) {
if !c.initDone {
return 0, errNetworkNotInitialized
}
return c.invokeNativePolicyMethod("getStoragePrice") return c.invokeNativePolicyMethod("getStoragePrice")
} }
@ -43,10 +34,11 @@ func (c *Client) GetMaxNotValidBeforeDelta() (int64, error) {
// invokeNativePolicy method invokes Get* method on a native Policy contract. // invokeNativePolicy method invokes Get* method on a native Policy contract.
func (c *Client) invokeNativePolicyMethod(operation string) (int64, error) { func (c *Client) invokeNativePolicyMethod(operation string) (int64, error) {
if !c.initDone { policyHash, err := c.GetNativeContractHash(nativenames.Policy)
return 0, errNetworkNotInitialized if err != nil {
return 0, fmt.Errorf("failed to get native Policy hash: %w", err)
} }
return c.invokeNativeGetMethod(c.cache.nativeHashes[nativenames.Policy], operation) return c.invokeNativeGetMethod(policyHash, operation)
} }
func (c *Client) invokeNativeGetMethod(hash util.Uint160, operation string) (int64, error) { func (c *Client) invokeNativeGetMethod(hash util.Uint160, operation string) (int64, error) {
@ -63,10 +55,11 @@ func (c *Client) invokeNativeGetMethod(hash util.Uint160, operation string) (int
// IsBlocked invokes `isBlocked` method on native Policy contract. // IsBlocked invokes `isBlocked` method on native Policy contract.
func (c *Client) IsBlocked(hash util.Uint160) (bool, error) { func (c *Client) IsBlocked(hash util.Uint160) (bool, error) {
if !c.initDone { policyHash, err := c.GetNativeContractHash(nativenames.Policy)
return false, errNetworkNotInitialized if err != nil {
return false, fmt.Errorf("failed to get native Policy hash: %w", err)
} }
result, err := c.InvokeFunction(c.cache.nativeHashes[nativenames.Policy], "isBlocked", []smartcontract.Parameter{{ result, err := c.InvokeFunction(policyHash, "isBlocked", []smartcontract.Parameter{{
Type: smartcontract.Hash160Type, Type: smartcontract.Hash160Type,
Value: hash, Value: hash,
}}, nil) }}, nil)

View file

@ -94,14 +94,15 @@ func (c *Client) getBlock(params request.RawParams) (*block.Block, error) {
err error err error
b *block.Block b *block.Block
) )
if !c.initDone {
return nil, errNetworkNotInitialized
}
if err = c.performRequest("getblock", params, &resp); err != nil { if err = c.performRequest("getblock", params, &resp); err != nil {
return nil, err return nil, err
} }
r := io.NewBinReaderFromBuf(resp) r := io.NewBinReaderFromBuf(resp)
b = block.New(c.StateRootInHeader()) sr, err := c.StateRootInHeader()
if err != nil {
return nil, err
}
b = block.New(sr)
b.DecodeBinary(r) b.DecodeBinary(r)
if r.Err != nil { if r.Err != nil {
return nil, r.Err return nil, r.Err
@ -127,9 +128,11 @@ func (c *Client) getBlockVerbose(params request.RawParams) (*result.Block, error
resp = &result.Block{} resp = &result.Block{}
err error err error
) )
if !c.initDone { sr, err := c.StateRootInHeader()
return nil, errNetworkNotInitialized if err != nil {
return nil, err
} }
resp.Header.StateRootEnabled = sr
if err = c.performRequest("getblock", params, resp); err != nil { if err = c.performRequest("getblock", params, resp); err != nil {
return nil, err return nil, err
} }
@ -157,14 +160,16 @@ func (c *Client) GetBlockHeader(hash util.Uint256) (*block.Header, error) {
resp []byte resp []byte
h *block.Header h *block.Header
) )
if !c.initDone {
return nil, errNetworkNotInitialized
}
if err := c.performRequest("getblockheader", params, &resp); err != nil { if err := c.performRequest("getblockheader", params, &resp); err != nil {
return nil, err return nil, err
} }
sr, err := c.StateRootInHeader()
if err != nil {
return nil, err
}
r := io.NewBinReaderFromBuf(resp) r := io.NewBinReaderFromBuf(resp)
h = new(block.Header) h = new(block.Header)
h.StateRootEnabled = sr
h.DecodeBinary(r) h.DecodeBinary(r)
if r.Err != nil { if r.Err != nil {
return nil, r.Err return nil, r.Err
@ -266,6 +271,14 @@ func (c *Client) GetNativeContracts() ([]state.NativeContract, error) {
if err := c.performRequest("getnativecontracts", params, &resp); err != nil { if err := c.performRequest("getnativecontracts", params, &resp); err != nil {
return resp, err return resp, err
} }
// Update native contract hashes.
c.cacheLock.Lock()
for _, cs := range resp {
c.cache.nativeHashes[cs.Manifest.Name] = cs.Hash
}
c.cacheLock.Unlock()
return resp, nil return resp, nil
} }
@ -399,17 +412,13 @@ func (c *Client) GetRawMemPool() ([]util.Uint256, error) {
return *resp, nil return *resp, nil
} }
// GetRawTransaction returns a transaction by hash. You should initialize network magic // GetRawTransaction returns a transaction by hash.
// with Init before calling GetRawTransaction.
func (c *Client) GetRawTransaction(hash util.Uint256) (*transaction.Transaction, error) { func (c *Client) GetRawTransaction(hash util.Uint256) (*transaction.Transaction, error) {
var ( var (
params = request.NewRawParams(hash.StringLE()) params = request.NewRawParams(hash.StringLE())
resp []byte resp []byte
err error err error
) )
if !c.initDone {
return nil, errNetworkNotInitialized
}
if err = c.performRequest("getrawtransaction", params, &resp); err != nil { if err = c.performRequest("getrawtransaction", params, &resp); err != nil {
return nil, err return nil, err
} }
@ -421,8 +430,7 @@ func (c *Client) GetRawTransaction(hash util.Uint256) (*transaction.Transaction,
} }
// GetRawTransactionVerbose returns a transaction wrapper with additional // GetRawTransactionVerbose returns a transaction wrapper with additional
// metadata by transaction's hash. You should initialize network magic // metadata by transaction's hash.
// with Init before calling GetRawTransactionVerbose.
// NOTE: to get transaction.ID and transaction.Size, use t.Hash() and io.GetVarSize(t) respectively. // NOTE: to get transaction.ID and transaction.Size, use t.Hash() and io.GetVarSize(t) respectively.
func (c *Client) GetRawTransactionVerbose(hash util.Uint256) (*result.TransactionOutputRaw, error) { func (c *Client) GetRawTransactionVerbose(hash util.Uint256) (*result.TransactionOutputRaw, error) {
var ( var (
@ -430,9 +438,6 @@ func (c *Client) GetRawTransactionVerbose(hash util.Uint256) (*result.Transactio
resp = &result.TransactionOutputRaw{} resp = &result.TransactionOutputRaw{}
err error err error
) )
if !c.initDone {
return nil, errNetworkNotInitialized
}
if err = c.performRequest("getrawtransaction", params, resp); err != nil { if err = c.performRequest("getrawtransaction", params, resp); err != nil {
return nil, err return nil, err
} }
@ -687,7 +692,11 @@ func (c *Client) SignAndPushTx(tx *transaction.Transaction, acc *wallet.Account,
txHash util.Uint256 txHash util.Uint256
err error err error
) )
if err = acc.SignTx(c.GetNetwork(), tx); err != nil { m, err := c.GetNetwork()
if err != nil {
return txHash, fmt.Errorf("failed to sign tx: %w", err)
}
if err = acc.SignTx(m, tx); err != nil {
return txHash, fmt.Errorf("failed to sign tx: %w", err) return txHash, fmt.Errorf("failed to sign tx: %w", err)
} }
// try to add witnesses for the rest of the signers // try to add witnesses for the rest of the signers
@ -695,7 +704,7 @@ func (c *Client) SignAndPushTx(tx *transaction.Transaction, acc *wallet.Account,
var isOk bool var isOk bool
for _, cosigner := range cosigners { for _, cosigner := range cosigners {
if signer.Account == cosigner.Signer.Account { if signer.Account == cosigner.Signer.Account {
err = cosigner.Account.SignTx(c.GetNetwork(), tx) err = cosigner.Account.SignTx(m, tx)
if err != nil { // then account is non-contract-based and locked, but let's provide more detailed error if err != nil { // then account is non-contract-based and locked, but let's provide more detailed error
if paramNum := len(cosigner.Account.Contract.Parameters); paramNum != 0 && cosigner.Account.Contract.Deployed { if paramNum := len(cosigner.Account.Contract.Parameters); paramNum != 0 && cosigner.Account.Contract.Deployed {
return txHash, fmt.Errorf("failed to add contract-based witness for signer #%d (%s): "+ return txHash, fmt.Errorf("failed to add contract-based witness for signer #%d (%s): "+
@ -771,9 +780,6 @@ func getSigners(sender *wallet.Account, cosigners []SignerAccount) ([]transactio
// Note: client should be initialized before SignAndPushP2PNotaryRequest call. // Note: client should be initialized before SignAndPushP2PNotaryRequest call.
func (c *Client) SignAndPushP2PNotaryRequest(mainTx *transaction.Transaction, fallbackScript []byte, fallbackSysFee int64, fallbackNetFee int64, fallbackValidFor uint32, acc *wallet.Account) (*payload.P2PNotaryRequest, error) { func (c *Client) SignAndPushP2PNotaryRequest(mainTx *transaction.Transaction, fallbackScript []byte, fallbackSysFee int64, fallbackNetFee int64, fallbackValidFor uint32, acc *wallet.Account) (*payload.P2PNotaryRequest, error) {
var err error var err error
if !c.initDone {
return nil, errNetworkNotInitialized
}
notaryHash, err := c.GetNativeContractHash(nativenames.Notary) notaryHash, err := c.GetNativeContractHash(nativenames.Notary)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get native Notary hash: %w", err) return nil, fmt.Errorf("failed to get native Notary hash: %w", err)
@ -835,7 +841,11 @@ func (c *Client) SignAndPushP2PNotaryRequest(mainTx *transaction.Transaction, fa
VerificationScript: []byte{}, VerificationScript: []byte{},
}, },
} }
if err = acc.SignTx(c.GetNetwork(), fallbackTx); err != nil { m, err := c.GetNetwork()
if err != nil {
return nil, fmt.Errorf("failed to sign fallback tx: %w", err)
}
if err = acc.SignTx(m, fallbackTx); err != nil {
return nil, fmt.Errorf("failed to sign fallback tx: %w", err) return nil, fmt.Errorf("failed to sign fallback tx: %w", err)
} }
fallbackHash := fallbackTx.Hash() fallbackHash := fallbackTx.Hash()
@ -844,7 +854,7 @@ func (c *Client) SignAndPushP2PNotaryRequest(mainTx *transaction.Transaction, fa
FallbackTransaction: fallbackTx, FallbackTransaction: fallbackTx,
} }
req.Witness = transaction.Witness{ req.Witness = transaction.Witness{
InvocationScript: append([]byte{byte(opcode.PUSHDATA1), 64}, acc.PrivateKey().SignHashable(uint32(c.GetNetwork()), req)...), InvocationScript: append([]byte{byte(opcode.PUSHDATA1), 64}, acc.PrivateKey().SignHashable(uint32(m), req)...),
VerificationScript: acc.GetVerificationScript(), VerificationScript: acc.GetVerificationScript(),
} }
actualHash, err := c.SubmitP2PNotaryRequest(req) actualHash, err := c.SubmitP2PNotaryRequest(req)
@ -922,18 +932,23 @@ func (c *Client) CalculateValidUntilBlock() (uint32, error) {
return result, fmt.Errorf("can't get block count: %w", err) return result, fmt.Errorf("can't get block count: %w", err)
} }
c.cacheLock.RLock()
if c.cache.calculateValidUntilBlock.expiresAt > blockCount { if c.cache.calculateValidUntilBlock.expiresAt > blockCount {
validatorsCount = c.cache.calculateValidUntilBlock.validatorsCount validatorsCount = c.cache.calculateValidUntilBlock.validatorsCount
c.cacheLock.RUnlock()
} else { } else {
c.cacheLock.RUnlock()
validators, err := c.GetNextBlockValidators() validators, err := c.GetNextBlockValidators()
if err != nil { if err != nil {
return result, fmt.Errorf("can't get validators: %w", err) return result, fmt.Errorf("can't get validators: %w", err)
} }
validatorsCount = uint32(len(validators)) validatorsCount = uint32(len(validators))
c.cacheLock.Lock()
c.cache.calculateValidUntilBlock = calculateValidUntilBlockCache{ c.cache.calculateValidUntilBlock = calculateValidUntilBlockCache{
validatorsCount: validatorsCount, validatorsCount: validatorsCount,
expiresAt: blockCount + cacheTimeout, expiresAt: blockCount + cacheTimeout,
} }
c.cacheLock.Unlock()
} }
return blockCount + validatorsCount + 1, nil return blockCount + validatorsCount + 1, nil
} }
@ -990,18 +1005,33 @@ func (c *Client) AddNetworkFee(tx *transaction.Transaction, extraFee int64, accs
} }
// GetNetwork returns the network magic of the RPC node client connected to. // GetNetwork returns the network magic of the RPC node client connected to.
func (c *Client) GetNetwork() netmode.Magic { func (c *Client) GetNetwork() (netmode.Magic, error) {
return c.network c.cacheLock.RLock()
defer c.cacheLock.RUnlock()
if !c.cache.initDone {
return 0, errNetworkNotInitialized
}
return c.cache.network, nil
} }
// StateRootInHeader returns true if state root is contained in block header. // StateRootInHeader returns true if state root is contained in block header.
func (c *Client) StateRootInHeader() bool { // You should initialize Client cache with Init() before calling StateRootInHeader.
return c.stateRootInHeader func (c *Client) StateRootInHeader() (bool, error) {
c.cacheLock.RLock()
defer c.cacheLock.RUnlock()
if !c.cache.initDone {
return false, errNetworkNotInitialized
}
return c.cache.stateRootInHeader, nil
} }
// GetNativeContractHash returns native contract hash by its name. // GetNativeContractHash returns native contract hash by its name.
func (c *Client) GetNativeContractHash(name string) (util.Uint160, error) { func (c *Client) GetNativeContractHash(name string) (util.Uint160, error) {
c.cacheLock.RLock()
hash, ok := c.cache.nativeHashes[name] hash, ok := c.cache.nativeHashes[name]
c.cacheLock.RUnlock()
if ok { if ok {
return hash, nil return hash, nil
} }
@ -1009,6 +1039,8 @@ func (c *Client) GetNativeContractHash(name string) (util.Uint160, error) {
if err != nil { if err != nil {
return util.Uint160{}, err return util.Uint160{}, err
} }
c.cacheLock.Lock()
c.cache.nativeHashes[name] = cs.Hash c.cache.nativeHashes[name] = cs.Hash
c.cacheLock.Unlock()
return cs.Hash, nil return cs.Hash, nil
} }

View file

@ -6,6 +6,7 @@ import (
"encoding/base64" "encoding/base64"
"encoding/hex" "encoding/hex"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"math/big" "math/big"
"net/http" "net/http"
@ -1704,6 +1705,7 @@ func TestRPCClients(t *testing.T) {
testRPCClient(t, func(ctx context.Context, endpoint string, opts Options) (*Client, error) { testRPCClient(t, func(ctx context.Context, endpoint string, opts Options) (*Client, error) {
c, err := New(ctx, endpoint, opts) c, err := New(ctx, endpoint, opts)
require.NoError(t, err) require.NoError(t, err)
c.getNextRequestID = getTestRequestID
require.NoError(t, c.Init()) require.NoError(t, c.Init())
return c, nil return c, nil
}) })
@ -1712,6 +1714,7 @@ func TestRPCClients(t *testing.T) {
testRPCClient(t, func(ctx context.Context, endpoint string, opts Options) (*Client, error) { testRPCClient(t, func(ctx context.Context, endpoint string, opts Options) (*Client, error) {
wsc, err := NewWS(ctx, httpURLtoWS(endpoint), opts) wsc, err := NewWS(ctx, httpURLtoWS(endpoint), opts)
require.NoError(t, err) require.NoError(t, err)
wsc.getNextRequestID = getTestRequestID
require.NoError(t, wsc.Init()) require.NoError(t, wsc.Init())
return &wsc.Client, nil return &wsc.Client, nil
}) })
@ -1731,6 +1734,7 @@ func testRPCClient(t *testing.T, newClient func(context.Context, string, Options
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
c.getNextRequestID = getTestRequestID
actual, err := testCase.invoke(c) actual, err := testCase.invoke(c)
if testCase.fails { if testCase.fails {
@ -1754,14 +1758,14 @@ func testRPCClient(t *testing.T, newClient func(context.Context, string, Options
endpoint := srv.URL endpoint := srv.URL
opts := Options{} opts := Options{}
c, err := newClient(context.TODO(), endpoint, opts)
if err != nil {
t.Fatal(err)
}
for _, testCase := range testBatch { for _, testCase := range testBatch {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
_, err := testCase.invoke(c) c, err := newClient(context.TODO(), endpoint, opts)
if err != nil {
t.Fatal(err)
}
c.getNextRequestID = getTestRequestID
_, err = testCase.invoke(c)
assert.Error(t, err) assert.Error(t, err)
}) })
} }
@ -1877,6 +1881,7 @@ func TestCalculateValidUntilBlock(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
c.getNextRequestID = getTestRequestID
require.NoError(t, c.Init()) require.NoError(t, c.Init())
validUntilBlock, err := c.CalculateValidUntilBlock() validUntilBlock, err := c.CalculateValidUntilBlock()
@ -1912,9 +1917,11 @@ func TestGetNetwork(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
c.getNextRequestID = getTestRequestID
// network was not initialised // network was not initialised
require.Equal(t, netmode.Magic(0), c.GetNetwork()) _, err = c.GetNetwork()
require.Equal(t, false, c.initDone) require.True(t, errors.Is(err, errNetworkNotInitialized))
require.Equal(t, false, c.cache.initDone)
}) })
t.Run("good", func(t *testing.T) { t.Run("good", func(t *testing.T) {
@ -1922,8 +1929,11 @@ func TestGetNetwork(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
c.getNextRequestID = getTestRequestID
require.NoError(t, c.Init()) require.NoError(t, c.Init())
require.Equal(t, netmode.UnitTestNet, c.GetNetwork()) m, err := c.GetNetwork()
require.NoError(t, err)
require.Equal(t, netmode.UnitTestNet, m)
}) })
} }
@ -1941,6 +1951,7 @@ func TestUninitedClient(t *testing.T) {
c, err := New(context.TODO(), endpoint, opts) c, err := New(context.TODO(), endpoint, opts)
require.NoError(t, err) require.NoError(t, err)
c.getNextRequestID = getTestRequestID
_, err = c.GetBlockByIndex(0) _, err = c.GetBlockByIndex(0)
require.Error(t, err) require.Error(t, err)
@ -1966,3 +1977,7 @@ func newTestNEF(script []byte) nef.File {
ne.Checksum = ne.CalculateChecksum() ne.Checksum = ne.CalculateChecksum()
return ne return ne
} }
func getTestRequestID() uint64 {
return 1
}

View file

@ -4,6 +4,8 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"errors" "errors"
"strconv"
"sync"
"time" "time"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
@ -20,6 +22,8 @@ import (
// servers. It's supposed to be faster than Client because it has persistent // servers. It's supposed to be faster than Client because it has persistent
// connection to the server and at the same time is exposes some functionality // connection to the server and at the same time is exposes some functionality
// that is only provided via websockets (like event subscription mechanism). // that is only provided via websockets (like event subscription mechanism).
// WSClient is thread-safe and can be used from multiple goroutines to perform
// RPC requests.
type WSClient struct { type WSClient struct {
Client Client
// Notifications is a channel that is used to send events received from // Notifications is a channel that is used to send events received from
@ -30,12 +34,16 @@ type WSClient struct {
// be closed, so make sure to handle this. // be closed, so make sure to handle this.
Notifications chan Notification Notifications chan Notification
ws *websocket.Conn ws *websocket.Conn
done chan struct{} done chan struct{}
responses chan *response.Raw requests chan *request.Raw
requests chan *request.Raw shutdown chan struct{}
shutdown chan struct{}
subscriptions map[string]bool subscriptionsLock sync.RWMutex
subscriptions map[string]bool
respLock sync.RWMutex
respChannels map[uint64]chan *response.Raw
} }
// Notification represents server-generated notification for client subscriptions. // Notification represents server-generated notification for client subscriptions.
@ -73,29 +81,29 @@ const (
// You should call Init method to initialize network magic the client is // You should call Init method to initialize network magic the client is
// operating on. // operating on.
func NewWS(ctx context.Context, endpoint string, opts Options) (*WSClient, error) { func NewWS(ctx context.Context, endpoint string, opts Options) (*WSClient, error) {
cl, err := New(ctx, endpoint, opts)
if err != nil {
return nil, err
}
cl.cli = nil
dialer := websocket.Dialer{HandshakeTimeout: opts.DialTimeout} dialer := websocket.Dialer{HandshakeTimeout: opts.DialTimeout}
ws, _, err := dialer.Dial(endpoint, nil) ws, _, err := dialer.Dial(endpoint, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
wsc := &WSClient{ wsc := &WSClient{
Client: *cl, Client: Client{},
Notifications: make(chan Notification), Notifications: make(chan Notification),
ws: ws, ws: ws,
shutdown: make(chan struct{}), shutdown: make(chan struct{}),
done: make(chan struct{}), done: make(chan struct{}),
responses: make(chan *response.Raw), respChannels: make(map[uint64]chan *response.Raw),
requests: make(chan *request.Raw), requests: make(chan *request.Raw),
subscriptions: make(map[string]bool), subscriptions: make(map[string]bool),
} }
err = initClient(ctx, &wsc.Client, endpoint, opts)
if err != nil {
return nil, err
}
wsc.Client.cli = nil
go wsc.wsReader() go wsc.wsReader()
go wsc.wsWriter() go wsc.wsWriter()
wsc.requestF = wsc.makeWsRequest wsc.requestF = wsc.makeWsRequest
@ -141,7 +149,12 @@ readloop:
var val interface{} var val interface{}
switch event { switch event {
case response.BlockEventID: case response.BlockEventID:
val = block.New(c.StateRootInHeader()) sr, err := c.StateRootInHeader()
if err != nil {
// Client is not initialised.
break
}
val = block.New(sr)
case response.TransactionEventID: case response.TransactionEventID:
val = &transaction.Transaction{} val = &transaction.Transaction{}
case response.NotificationEventID: case response.NotificationEventID:
@ -170,14 +183,27 @@ readloop:
resp.JSONRPC = rr.JSONRPC resp.JSONRPC = rr.JSONRPC
resp.Error = rr.Error resp.Error = rr.Error
resp.Result = rr.Result resp.Result = rr.Result
c.responses <- resp id, err := strconv.Atoi(string(resp.ID))
if err != nil {
break // Malformed response (invalid response ID).
}
ch := c.getResponseChannel(uint64(id))
if ch == nil {
break // Unknown response (unexpected response ID).
}
ch <- resp
} else { } else {
// Malformed response, neither valid request, nor valid response. // Malformed response, neither valid request, nor valid response.
break break
} }
} }
close(c.done) close(c.done)
close(c.responses) c.respLock.Lock()
for _, ch := range c.respChannels {
close(ch)
}
c.respChannels = nil
c.respLock.Unlock()
close(c.Notifications) close(c.Notifications)
} }
@ -212,16 +238,41 @@ func (c *WSClient) wsWriter() {
} }
} }
func (c *WSClient) registerRespChannel(id uint64, ch chan *response.Raw) {
c.respLock.Lock()
defer c.respLock.Unlock()
c.respChannels[id] = ch
}
func (c *WSClient) unregisterRespChannel(id uint64) {
c.respLock.Lock()
defer c.respLock.Unlock()
if ch, ok := c.respChannels[id]; ok {
delete(c.respChannels, id)
close(ch)
}
}
func (c *WSClient) getResponseChannel(id uint64) chan *response.Raw {
c.respLock.RLock()
defer c.respLock.RUnlock()
return c.respChannels[id]
}
func (c *WSClient) makeWsRequest(r *request.Raw) (*response.Raw, error) { func (c *WSClient) makeWsRequest(r *request.Raw) (*response.Raw, error) {
ch := make(chan *response.Raw)
c.registerRespChannel(r.ID, ch)
select { select {
case <-c.done: case <-c.done:
return nil, errors.New("connection lost") return nil, errors.New("connection lost before sending the request")
case c.requests <- r: case c.requests <- r:
} }
select { select {
case <-c.done: case <-c.done:
return nil, errors.New("connection lost") return nil, errors.New("connection lost while waiting for the response")
case resp := <-c.responses: case resp := <-ch:
c.unregisterRespChannel(r.ID)
return resp, nil return resp, nil
} }
} }
@ -232,6 +283,10 @@ func (c *WSClient) performSubscription(params request.RawParams) (string, error)
if err := c.performRequest("subscribe", params, &resp); err != nil { if err := c.performRequest("subscribe", params, &resp); err != nil {
return "", err return "", err
} }
c.subscriptionsLock.Lock()
defer c.subscriptionsLock.Unlock()
c.subscriptions[resp] = true c.subscriptions[resp] = true
return resp, nil return resp, nil
} }
@ -239,6 +294,9 @@ func (c *WSClient) performSubscription(params request.RawParams) (string, error)
func (c *WSClient) performUnsubscription(id string) error { func (c *WSClient) performUnsubscription(id string) error {
var resp bool var resp bool
c.subscriptionsLock.Lock()
defer c.subscriptionsLock.Unlock()
if !c.subscriptions[id] { if !c.subscriptions[id] {
return errors.New("no subscription with this ID") return errors.New("no subscription with this ID")
} }
@ -320,11 +378,18 @@ func (c *WSClient) Unsubscribe(id string) error {
// UnsubscribeAll removes all active subscriptions of current client. // UnsubscribeAll removes all active subscriptions of current client.
func (c *WSClient) UnsubscribeAll() error { func (c *WSClient) UnsubscribeAll() error {
c.subscriptionsLock.Lock()
defer c.subscriptionsLock.Unlock()
for id := range c.subscriptions { for id := range c.subscriptions {
err := c.performUnsubscription(id) var resp bool
if err != nil { if err := c.performRequest("unsubscribe", request.NewRawParams(id), &resp); err != nil {
return err return err
} }
if !resp {
return errors.New("unsubscribe method returned false result")
}
delete(c.subscriptions, id)
} }
return nil return nil
} }

View file

@ -6,7 +6,10 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"sort"
"strconv"
"strings" "strings"
"sync"
"testing" "testing"
"time" "time"
@ -15,6 +18,7 @@ import (
"github.com/nspcc-dev/neo-go/pkg/rpc/request" "github.com/nspcc-dev/neo-go/pkg/rpc/request"
"github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/util"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.uber.org/atomic"
) )
func TestWSClientClose(t *testing.T) { func TestWSClientClose(t *testing.T) {
@ -45,6 +49,7 @@ func TestWSClientSubscription(t *testing.T) {
srv := initTestServer(t, `{"jsonrpc": "2.0", "id": 1, "result": "55aaff00"}`) srv := initTestServer(t, `{"jsonrpc": "2.0", "id": 1, "result": "55aaff00"}`)
wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{})
require.NoError(t, err) require.NoError(t, err)
wsc.getNextRequestID = getTestRequestID
require.NoError(t, wsc.Init()) require.NoError(t, wsc.Init())
id, err := f(wsc) id, err := f(wsc)
require.NoError(t, err) require.NoError(t, err)
@ -58,6 +63,7 @@ func TestWSClientSubscription(t *testing.T) {
srv := initTestServer(t, `{"jsonrpc": "2.0", "id": 1, "error":{"code":-32602,"message":"Invalid Params"}}`) srv := initTestServer(t, `{"jsonrpc": "2.0", "id": 1, "error":{"code":-32602,"message":"Invalid Params"}}`)
wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{})
require.NoError(t, err) require.NoError(t, err)
wsc.getNextRequestID = getTestRequestID
require.NoError(t, wsc.Init()) require.NoError(t, wsc.Init())
_, err = f(wsc) _, err = f(wsc)
require.Error(t, err) require.Error(t, err)
@ -107,6 +113,7 @@ func TestWSClientUnsubscription(t *testing.T) {
srv := initTestServer(t, rc.response) srv := initTestServer(t, rc.response)
wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{})
require.NoError(t, err) require.NoError(t, err)
wsc.getNextRequestID = getTestRequestID
require.NoError(t, wsc.Init()) require.NoError(t, wsc.Init())
rc.code(t, wsc) rc.code(t, wsc)
}) })
@ -143,7 +150,9 @@ func TestWSClientEvents(t *testing.T) {
wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{})
require.NoError(t, err) require.NoError(t, err)
wsc.network = netmode.UnitTestNet wsc.getNextRequestID = getTestRequestID
wsc.cache.initDone = true // Our server mock is restricted, so perform initialisation manually.
wsc.cache.network = netmode.UnitTestNet
for range events { for range events {
select { select {
case _, ok = <-wsc.Notifications: case _, ok = <-wsc.Notifications:
@ -166,6 +175,7 @@ func TestWSExecutionVMStateCheck(t *testing.T) {
srv := initTestServer(t, `{"jsonrpc": "2.0", "id": 1, "result": "55aaff00"}`) srv := initTestServer(t, `{"jsonrpc": "2.0", "id": 1, "result": "55aaff00"}`)
wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{})
require.NoError(t, err) require.NoError(t, err)
wsc.getNextRequestID = getTestRequestID
require.NoError(t, wsc.Init()) require.NoError(t, wsc.Init())
filter := "NONE" filter := "NONE"
_, err = wsc.SubscribeForTransactionExecutions(&filter) _, err = wsc.SubscribeForTransactionExecutions(&filter)
@ -315,7 +325,8 @@ func TestWSFilteredSubscriptions(t *testing.T) {
})) }))
wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{})
require.NoError(t, err) require.NoError(t, err)
wsc.network = netmode.UnitTestNet wsc.getNextRequestID = getTestRequestID
wsc.cache.network = netmode.UnitTestNet
c.clientCode(t, wsc) c.clientCode(t, wsc)
wsc.Close() wsc.Close()
}) })
@ -328,6 +339,8 @@ func TestNewWS(t *testing.T) {
t.Run("good", func(t *testing.T) { t.Run("good", func(t *testing.T) {
c, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) c, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{})
require.NoError(t, err) require.NoError(t, err)
c.getNextRequestID = getTestRequestID
c.cache.network = netmode.UnitTestNet
require.NoError(t, c.Init()) require.NoError(t, c.Init())
}) })
t.Run("bad URL", func(t *testing.T) { t.Run("bad URL", func(t *testing.T) {
@ -335,3 +348,96 @@ func TestNewWS(t *testing.T) {
require.Error(t, err) require.Error(t, err)
}) })
} }
func TestWSConcurrentAccess(t *testing.T) {
var ids struct {
lock sync.RWMutex
m map[int]struct{}
}
ids.m = make(map[int]struct{})
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if req.URL.Path == "/ws" && req.Method == "GET" {
var upgrader = websocket.Upgrader{}
ws, err := upgrader.Upgrade(w, req, nil)
require.NoError(t, err)
for {
err = ws.SetReadDeadline(time.Now().Add(2 * time.Second))
require.NoError(t, err)
_, p, err := ws.ReadMessage()
if err != nil {
break
}
r := request.NewIn()
err = json.Unmarshal(p, r)
if err != nil {
t.Fatalf("Cannot decode request body: %s", req.Body)
}
i, err := strconv.Atoi(string(r.RawID))
require.NoError(t, err)
ids.lock.Lock()
ids.m[i] = struct{}{}
ids.lock.Unlock()
var response string
// Different responses to catch possible unmarshalling errors connected with invalid IDs distribution.
switch r.Method {
case "getblockcount":
response = fmt.Sprintf(`{"id":%s,"jsonrpc":"2.0","result":123}`, r.RawID)
case "getversion":
response = fmt.Sprintf(`{"id":%s,"jsonrpc":"2.0","result":{"network":42,"tcpport":20332,"wsport":20342,"nonce":2153672787,"useragent":"/NEO-GO:0.73.1-pre-273-ge381358/"}}`, r.RawID)
case "getblockhash":
response = fmt.Sprintf(`{"id":%s,"jsonrpc":"2.0","result":"0x157ca5e5b8cf8f84c9660502a3270b346011612bded1514a6847f877c433a9bb"}`, r.RawID)
}
err = ws.SetWriteDeadline(time.Now().Add(2 * time.Second))
require.NoError(t, err)
err = ws.WriteMessage(1, []byte(response))
if err != nil {
break
}
}
ws.Close()
return
}
}))
t.Cleanup(srv.Close)
wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{})
require.NoError(t, err)
batchCount := 100
completed := atomic.NewInt32(0)
for i := 0; i < batchCount; i++ {
go func() {
_, err := wsc.GetBlockCount()
require.NoError(t, err)
completed.Inc()
}()
go func() {
_, err := wsc.GetBlockHash(123)
require.NoError(t, err)
completed.Inc()
}()
go func() {
_, err := wsc.GetVersion()
require.NoError(t, err)
completed.Inc()
}()
}
require.Eventually(t, func() bool {
return int(completed.Load()) == batchCount*3
}, time.Second, 100*time.Millisecond)
ids.lock.RLock()
require.True(t, len(ids.m) > batchCount)
idsList := make([]int, 0, len(ids.m))
for i := range ids.m {
idsList = append(idsList, i)
}
ids.lock.RUnlock()
sort.Ints(idsList)
require.Equal(t, 1, idsList[0])
require.Less(t, idsList[len(idsList)-1],
batchCount*3+1) // batchCount*requestsPerBatch+1
wsc.Close()
}

View file

@ -32,12 +32,12 @@ func NewRawParams(vals ...interface{}) RawParams {
return p return p
} }
// Raw represents JSON-RPC request. // Raw represents JSON-RPC request on the Client side.
type Raw struct { type Raw struct {
JSONRPC string `json:"jsonrpc"` JSONRPC string `json:"jsonrpc"`
Method string `json:"method"` Method string `json:"method"`
RawParams []interface{} `json:"params"` RawParams []interface{} `json:"params"`
ID int `json:"id"` ID uint64 `json:"id"`
} }
// Request contains standard JSON-RPC 2.0 request and batch of // Request contains standard JSON-RPC 2.0 request and batch of

View file

@ -555,9 +555,11 @@ func TestSignAndPushP2PNotaryRequest(t *testing.T) {
c, err := client.New(context.Background(), httpSrv.URL, client.Options{}) c, err := client.New(context.Background(), httpSrv.URL, client.Options{})
require.NoError(t, err) require.NoError(t, err)
acc, err := wallet.NewAccount()
require.NoError(t, err)
t.Run("client wasn't initialized", func(t *testing.T) { t.Run("client wasn't initialized", func(t *testing.T) {
_, err := c.SignAndPushP2PNotaryRequest(nil, nil, 0, 0, 0, nil) _, err := c.SignAndPushP2PNotaryRequest(transaction.New([]byte{byte(opcode.RET)}, 123), []byte{byte(opcode.RET)}, -1, 0, 100, acc)
require.NotNil(t, err) require.NotNil(t, err)
}) })
@ -567,8 +569,6 @@ func TestSignAndPushP2PNotaryRequest(t *testing.T) {
require.NotNil(t, err) require.NotNil(t, err)
}) })
acc, err := wallet.NewAccount()
require.NoError(t, err)
t.Run("bad fallback script", func(t *testing.T) { t.Run("bad fallback script", func(t *testing.T) {
_, err := c.SignAndPushP2PNotaryRequest(nil, []byte{byte(opcode.ASSERT)}, -1, 0, 0, acc) _, err := c.SignAndPushP2PNotaryRequest(nil, []byte{byte(opcode.ASSERT)}, -1, 0, 0, acc)
require.NotNil(t, err) require.NotNil(t, err)
@ -649,7 +649,7 @@ func TestCalculateNotaryFee(t *testing.T) {
t.Run("client not initialized", func(t *testing.T) { t.Run("client not initialized", func(t *testing.T) {
_, err := c.CalculateNotaryFee(0) _, err := c.CalculateNotaryFee(0)
require.NotNil(t, err) require.NoError(t, err) // Do not require client initialisation for this.
}) })
} }