diff --git a/cli/smartcontract/smart_contract.go b/cli/smartcontract/smart_contract.go index aa8a19506..9e99b8686 100644 --- a/cli/smartcontract/smart_contract.go +++ b/cli/smartcontract/smart_contract.go @@ -682,7 +682,11 @@ func invokeWithArgs(ctx *cli.Context, acc *wallet.Account, wall *wallet.Wallet, if err != nil { return sender, cli.NewExitError(fmt.Errorf("failed to create tx: %w", err), 1) } - if err := paramcontext.InitAndSave(c.GetNetwork(), tx, acc, out); err != nil { + m, err := c.GetNetwork() + if err != nil { + return sender, cli.NewExitError(fmt.Errorf("failed to save tx: %w", err), 1) + } + if err := paramcontext.InitAndSave(m, tx, acc, out); err != nil { return sender, cli.NewExitError(err, 1) } fmt.Fprintln(ctx.App.Writer, tx.Hash().StringLE()) diff --git a/cli/wallet/nep11.go b/cli/wallet/nep11.go index d5adb21c0..335d35bf9 100644 --- a/cli/wallet/nep11.go +++ b/cli/wallet/nep11.go @@ -274,7 +274,11 @@ func signAndSendNEP11Transfer(ctx *cli.Context, c *client.Client, acc *wallet.Ac tx.SystemFee += int64(sysgas) if outFile := ctx.String("out"); outFile != "" { - if err := paramcontext.InitAndSave(c.GetNetwork(), tx, acc, outFile); err != nil { + m, err := c.GetNetwork() + if err != nil { + return cli.NewExitError(fmt.Errorf("failed to save tx: %w", err), 1) + } + if err := paramcontext.InitAndSave(m, tx, acc, outFile); err != nil { return cli.NewExitError(err, 1) } } else { diff --git a/cli/wallet/nep17.go b/cli/wallet/nep17.go index 417c488fe..a11766d83 100644 --- a/cli/wallet/nep17.go +++ b/cli/wallet/nep17.go @@ -647,7 +647,11 @@ func signAndSendNEP17Transfer(ctx *cli.Context, c *client.Client, acc *wallet.Ac tx.SystemFee += int64(sysgas) if outFile := ctx.String("out"); outFile != "" { - if err := paramcontext.InitAndSave(c.GetNetwork(), tx, acc, outFile); err != nil { + m, err := c.GetNetwork() + if err != nil { + return cli.NewExitError(fmt.Errorf("failed to save tx: %w", err), 1) + } + if err := paramcontext.InitAndSave(m, tx, acc, outFile); err != nil { return cli.NewExitError(err, 1) } } else { diff --git a/pkg/rpc/client/client.go b/pkg/rpc/client/client.go index 62d7612af..1337b6d52 100644 --- a/pkg/rpc/client/client.go +++ b/pkg/rpc/client/client.go @@ -9,6 +9,7 @@ import ( "net" "net/http" "net/url" + "sync" "time" "github.com/nspcc-dev/neo-go/pkg/config/netmode" @@ -16,6 +17,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/rpc/request" "github.com/nspcc-dev/neo-go/pkg/rpc/response" "github.com/nspcc-dev/neo-go/pkg/util" + "go.uber.org/atomic" ) const ( @@ -26,17 +28,26 @@ const ( ) // Client represents the middleman for executing JSON RPC calls -// to remote NEO RPC nodes. +// to remote NEO RPC nodes. Client is thread-safe and can be used from +// multiple goroutines. type Client struct { - cli *http.Client - endpoint *url.URL - network netmode.Magic - stateRootInHeader bool - initDone bool - ctx context.Context - opts Options - requestF func(*request.Raw) (*response.Raw, error) - cache cache + cli *http.Client + endpoint *url.URL + ctx context.Context + opts Options + requestF func(*request.Raw) (*response.Raw, error) + + cacheLock sync.RWMutex + // cache stores RPC node related information client is bound to. + // cache is mostly filled in during Init(), but can also be updated + // during regular Client lifecycle. + cache cache + + latestReqID *atomic.Uint64 + // getNextRequestID returns ID to be used for subsequent request creation. + // It is defined on Client so that our testing code can override this method + // for the sake of more predictable request IDs generation behaviour. + getNextRequestID func() uint64 } // Options defines options for the RPC client. @@ -56,6 +67,9 @@ type Options struct { // cache stores cache values for the RPC client methods. type cache struct { + initDone bool + network netmode.Magic + stateRootInHeader bool calculateValidUntilBlock calculateValidUntilBlockCache nativeHashes map[string]util.Uint160 } @@ -70,10 +84,19 @@ type calculateValidUntilBlockCache struct { // New returns a new Client ready to use. You should call Init method to // initialize network magic the client is operating on. func New(ctx context.Context, endpoint string, opts Options) (*Client, error) { - url, err := url.Parse(endpoint) + cl := new(Client) + err := initClient(ctx, cl, endpoint, opts) if err != nil { return nil, err } + return cl, nil +} + +func initClient(ctx context.Context, cl *Client, endpoint string, opts Options) error { + url, err := url.Parse(endpoint) + if err != nil { + return err + } if opts.DialTimeout <= 0 { opts.DialTimeout = defaultDialTimeout @@ -97,33 +120,41 @@ func New(ctx context.Context, endpoint string, opts Options) (*Client, error) { // if opts.Cert != "" && opts.Key != "" { // } - cl := &Client{ - ctx: ctx, - cli: httpClient, - endpoint: url, - cache: cache{ - nativeHashes: make(map[string]util.Uint160), - }, + cl.ctx = ctx + cl.cli = httpClient + cl.endpoint = url + cl.cache = cache{ + nativeHashes: make(map[string]util.Uint160), } + cl.latestReqID = atomic.NewUint64(0) + cl.getNextRequestID = (cl).getRequestID cl.opts = opts cl.requestF = cl.makeHTTPRequest - return cl, nil + return nil +} + +func (c *Client) getRequestID() uint64 { + return c.latestReqID.Inc() } // Init sets magic of the network client connected to, stateRootInHeader option // and native NEO, GAS and Policy contracts scripthashes. This method should be -// called before any transaction-, header- or block-related requests in order to -// deserialize responses properly. +// called before any header- or block-related requests in order to deserialize +// responses properly. func (c *Client) Init() error { version, err := c.GetVersion() if err != nil { return fmt.Errorf("failed to get network magic: %w", err) } - c.network = version.Protocol.Network - c.stateRootInHeader = version.Protocol.StateRootInHeader + + c.cacheLock.Lock() + defer c.cacheLock.Unlock() + + c.cache.network = version.Protocol.Network + c.cache.stateRootInHeader = version.Protocol.StateRootInHeader if version.Protocol.MillisecondsPerBlock == 0 { - c.network = version.Magic - c.stateRootInHeader = version.StateRootInHeader + c.cache.network = version.Magic + c.cache.stateRootInHeader = version.StateRootInHeader } neoContractHash, err := c.GetContractStateByAddressOrName(nativenames.Neo) if err != nil { @@ -140,7 +171,7 @@ func (c *Client) Init() error { return fmt.Errorf("failed to get Policy contract scripthash: %w", err) } c.cache.nativeHashes[nativenames.Policy] = policyContractHash.Hash - c.initDone = true + c.cache.initDone = true return nil } @@ -149,7 +180,7 @@ func (c *Client) performRequest(method string, p request.RawParams, v interface{ JSONRPC: request.JSONRPCVersion, Method: method, RawParams: p.Values, - ID: 1, + ID: c.getNextRequestID(), } raw, err := c.requestF(&r) diff --git a/pkg/rpc/client/nep11.go b/pkg/rpc/client/nep11.go index 1a744e435..40217f517 100644 --- a/pkg/rpc/client/nep11.go +++ b/pkg/rpc/client/nep11.go @@ -46,9 +46,6 @@ func (c *Client) NEP11TokenInfo(tokenHash util.Uint160) (*wallet.Token, error) { // given account and sends it to the network returning just a hash of it. func (c *Client) TransferNEP11(acc *wallet.Account, to util.Uint160, tokenHash util.Uint160, tokenID string, data interface{}, gas int64, cosigners []SignerAccount) (util.Uint256, error) { - if !c.initDone { - return util.Uint256{}, errNetworkNotInitialized - } tx, err := c.CreateNEP11TransferTx(acc, tokenHash, gas, cosigners, to, tokenID, data) if err != nil { return util.Uint256{}, err @@ -144,9 +141,6 @@ func (c *Client) NEP11NDOwnerOf(tokenHash util.Uint160, tokenID []byte) (util.Ui // sends it to the network returning just a hash of it. func (c *Client) TransferNEP11D(acc *wallet.Account, to util.Uint160, tokenHash util.Uint160, amount int64, tokenID []byte, data interface{}, gas int64, cosigners []SignerAccount) (util.Uint256, error) { - if !c.initDone { - return util.Uint256{}, errNetworkNotInitialized - } from, err := address.StringToUint160(acc.Address) if err != nil { return util.Uint256{}, fmt.Errorf("bad account address: %w", err) diff --git a/pkg/rpc/client/nep17.go b/pkg/rpc/client/nep17.go index ee7685c20..22dcf6ea6 100644 --- a/pkg/rpc/client/nep17.go +++ b/pkg/rpc/client/nep17.go @@ -142,10 +142,6 @@ func (c *Client) CreateTxFromScript(script []byte, acc *wallet.Account, sysFee, // impossible (e.g. due to locked cosigner's account) an error is returned. func (c *Client) TransferNEP17(acc *wallet.Account, to util.Uint160, token util.Uint160, amount int64, gas int64, data interface{}, cosigners []SignerAccount) (util.Uint256, error) { - if !c.initDone { - return util.Uint256{}, errNetworkNotInitialized - } - tx, err := c.CreateNEP17TransferTx(acc, to, token, amount, gas, data, cosigners) if err != nil { return util.Uint256{}, err @@ -156,10 +152,6 @@ func (c *Client) TransferNEP17(acc *wallet.Account, to util.Uint160, token util. // MultiTransferNEP17 is similar to TransferNEP17, buf allows to have multiple recipients. func (c *Client) MultiTransferNEP17(acc *wallet.Account, gas int64, recipients []TransferTarget, cosigners []SignerAccount) (util.Uint256, error) { - if !c.initDone { - return util.Uint256{}, errNetworkNotInitialized - } - tx, err := c.CreateNEP17MultiTransferTx(acc, gas, recipients, cosigners) if err != nil { return util.Uint256{}, err diff --git a/pkg/rpc/client/policy.go b/pkg/rpc/client/policy.go index e8644a025..8fe8f0a2b 100644 --- a/pkg/rpc/client/policy.go +++ b/pkg/rpc/client/policy.go @@ -10,25 +10,16 @@ import ( // GetFeePerByte invokes `getFeePerByte` method on a native Policy contract. func (c *Client) GetFeePerByte() (int64, error) { - if !c.initDone { - return 0, errNetworkNotInitialized - } return c.invokeNativePolicyMethod("getFeePerByte") } // GetExecFeeFactor invokes `getExecFeeFactor` method on a native Policy contract. func (c *Client) GetExecFeeFactor() (int64, error) { - if !c.initDone { - return 0, errNetworkNotInitialized - } return c.invokeNativePolicyMethod("getExecFeeFactor") } // GetStoragePrice invokes `getStoragePrice` method on a native Policy contract. func (c *Client) GetStoragePrice() (int64, error) { - if !c.initDone { - return 0, errNetworkNotInitialized - } return c.invokeNativePolicyMethod("getStoragePrice") } @@ -43,10 +34,11 @@ func (c *Client) GetMaxNotValidBeforeDelta() (int64, error) { // invokeNativePolicy method invokes Get* method on a native Policy contract. func (c *Client) invokeNativePolicyMethod(operation string) (int64, error) { - if !c.initDone { - return 0, errNetworkNotInitialized + policyHash, err := c.GetNativeContractHash(nativenames.Policy) + if err != nil { + return 0, fmt.Errorf("failed to get native Policy hash: %w", err) } - return c.invokeNativeGetMethod(c.cache.nativeHashes[nativenames.Policy], operation) + return c.invokeNativeGetMethod(policyHash, operation) } func (c *Client) invokeNativeGetMethod(hash util.Uint160, operation string) (int64, error) { @@ -63,10 +55,11 @@ func (c *Client) invokeNativeGetMethod(hash util.Uint160, operation string) (int // IsBlocked invokes `isBlocked` method on native Policy contract. func (c *Client) IsBlocked(hash util.Uint160) (bool, error) { - if !c.initDone { - return false, errNetworkNotInitialized + policyHash, err := c.GetNativeContractHash(nativenames.Policy) + if err != nil { + return false, fmt.Errorf("failed to get native Policy hash: %w", err) } - result, err := c.InvokeFunction(c.cache.nativeHashes[nativenames.Policy], "isBlocked", []smartcontract.Parameter{{ + result, err := c.InvokeFunction(policyHash, "isBlocked", []smartcontract.Parameter{{ Type: smartcontract.Hash160Type, Value: hash, }}, nil) diff --git a/pkg/rpc/client/rpc.go b/pkg/rpc/client/rpc.go index 086876070..3d035d40f 100644 --- a/pkg/rpc/client/rpc.go +++ b/pkg/rpc/client/rpc.go @@ -94,14 +94,15 @@ func (c *Client) getBlock(params request.RawParams) (*block.Block, error) { err error b *block.Block ) - if !c.initDone { - return nil, errNetworkNotInitialized - } if err = c.performRequest("getblock", params, &resp); err != nil { return nil, err } r := io.NewBinReaderFromBuf(resp) - b = block.New(c.StateRootInHeader()) + sr, err := c.StateRootInHeader() + if err != nil { + return nil, err + } + b = block.New(sr) b.DecodeBinary(r) if r.Err != nil { return nil, r.Err @@ -127,9 +128,11 @@ func (c *Client) getBlockVerbose(params request.RawParams) (*result.Block, error resp = &result.Block{} err error ) - if !c.initDone { - return nil, errNetworkNotInitialized + sr, err := c.StateRootInHeader() + if err != nil { + return nil, err } + resp.Header.StateRootEnabled = sr if err = c.performRequest("getblock", params, resp); err != nil { return nil, err } @@ -157,14 +160,16 @@ func (c *Client) GetBlockHeader(hash util.Uint256) (*block.Header, error) { resp []byte h *block.Header ) - if !c.initDone { - return nil, errNetworkNotInitialized - } if err := c.performRequest("getblockheader", params, &resp); err != nil { return nil, err } + sr, err := c.StateRootInHeader() + if err != nil { + return nil, err + } r := io.NewBinReaderFromBuf(resp) h = new(block.Header) + h.StateRootEnabled = sr h.DecodeBinary(r) if r.Err != nil { return nil, r.Err @@ -266,6 +271,14 @@ func (c *Client) GetNativeContracts() ([]state.NativeContract, error) { if err := c.performRequest("getnativecontracts", params, &resp); err != nil { return resp, err } + + // Update native contract hashes. + c.cacheLock.Lock() + for _, cs := range resp { + c.cache.nativeHashes[cs.Manifest.Name] = cs.Hash + } + c.cacheLock.Unlock() + return resp, nil } @@ -399,17 +412,13 @@ func (c *Client) GetRawMemPool() ([]util.Uint256, error) { return *resp, nil } -// GetRawTransaction returns a transaction by hash. You should initialize network magic -// with Init before calling GetRawTransaction. +// GetRawTransaction returns a transaction by hash. func (c *Client) GetRawTransaction(hash util.Uint256) (*transaction.Transaction, error) { var ( params = request.NewRawParams(hash.StringLE()) resp []byte err error ) - if !c.initDone { - return nil, errNetworkNotInitialized - } if err = c.performRequest("getrawtransaction", params, &resp); err != nil { return nil, err } @@ -421,8 +430,7 @@ func (c *Client) GetRawTransaction(hash util.Uint256) (*transaction.Transaction, } // GetRawTransactionVerbose returns a transaction wrapper with additional -// metadata by transaction's hash. You should initialize network magic -// with Init before calling GetRawTransactionVerbose. +// metadata by transaction's hash. // NOTE: to get transaction.ID and transaction.Size, use t.Hash() and io.GetVarSize(t) respectively. func (c *Client) GetRawTransactionVerbose(hash util.Uint256) (*result.TransactionOutputRaw, error) { var ( @@ -430,9 +438,6 @@ func (c *Client) GetRawTransactionVerbose(hash util.Uint256) (*result.Transactio resp = &result.TransactionOutputRaw{} err error ) - if !c.initDone { - return nil, errNetworkNotInitialized - } if err = c.performRequest("getrawtransaction", params, resp); err != nil { return nil, err } @@ -687,7 +692,11 @@ func (c *Client) SignAndPushTx(tx *transaction.Transaction, acc *wallet.Account, txHash util.Uint256 err error ) - if err = acc.SignTx(c.GetNetwork(), tx); err != nil { + m, err := c.GetNetwork() + if err != nil { + return txHash, fmt.Errorf("failed to sign tx: %w", err) + } + if err = acc.SignTx(m, tx); err != nil { return txHash, fmt.Errorf("failed to sign tx: %w", err) } // try to add witnesses for the rest of the signers @@ -695,7 +704,7 @@ func (c *Client) SignAndPushTx(tx *transaction.Transaction, acc *wallet.Account, var isOk bool for _, cosigner := range cosigners { if signer.Account == cosigner.Signer.Account { - err = cosigner.Account.SignTx(c.GetNetwork(), tx) + err = cosigner.Account.SignTx(m, tx) if err != nil { // then account is non-contract-based and locked, but let's provide more detailed error if paramNum := len(cosigner.Account.Contract.Parameters); paramNum != 0 && cosigner.Account.Contract.Deployed { return txHash, fmt.Errorf("failed to add contract-based witness for signer #%d (%s): "+ @@ -771,9 +780,6 @@ func getSigners(sender *wallet.Account, cosigners []SignerAccount) ([]transactio // Note: client should be initialized before SignAndPushP2PNotaryRequest call. func (c *Client) SignAndPushP2PNotaryRequest(mainTx *transaction.Transaction, fallbackScript []byte, fallbackSysFee int64, fallbackNetFee int64, fallbackValidFor uint32, acc *wallet.Account) (*payload.P2PNotaryRequest, error) { var err error - if !c.initDone { - return nil, errNetworkNotInitialized - } notaryHash, err := c.GetNativeContractHash(nativenames.Notary) if err != nil { return nil, fmt.Errorf("failed to get native Notary hash: %w", err) @@ -835,7 +841,11 @@ func (c *Client) SignAndPushP2PNotaryRequest(mainTx *transaction.Transaction, fa VerificationScript: []byte{}, }, } - if err = acc.SignTx(c.GetNetwork(), fallbackTx); err != nil { + m, err := c.GetNetwork() + if err != nil { + return nil, fmt.Errorf("failed to sign fallback tx: %w", err) + } + if err = acc.SignTx(m, fallbackTx); err != nil { return nil, fmt.Errorf("failed to sign fallback tx: %w", err) } fallbackHash := fallbackTx.Hash() @@ -844,7 +854,7 @@ func (c *Client) SignAndPushP2PNotaryRequest(mainTx *transaction.Transaction, fa FallbackTransaction: fallbackTx, } req.Witness = transaction.Witness{ - InvocationScript: append([]byte{byte(opcode.PUSHDATA1), 64}, acc.PrivateKey().SignHashable(uint32(c.GetNetwork()), req)...), + InvocationScript: append([]byte{byte(opcode.PUSHDATA1), 64}, acc.PrivateKey().SignHashable(uint32(m), req)...), VerificationScript: acc.GetVerificationScript(), } actualHash, err := c.SubmitP2PNotaryRequest(req) @@ -922,18 +932,23 @@ func (c *Client) CalculateValidUntilBlock() (uint32, error) { return result, fmt.Errorf("can't get block count: %w", err) } + c.cacheLock.RLock() if c.cache.calculateValidUntilBlock.expiresAt > blockCount { validatorsCount = c.cache.calculateValidUntilBlock.validatorsCount + c.cacheLock.RUnlock() } else { + c.cacheLock.RUnlock() validators, err := c.GetNextBlockValidators() if err != nil { return result, fmt.Errorf("can't get validators: %w", err) } validatorsCount = uint32(len(validators)) + c.cacheLock.Lock() c.cache.calculateValidUntilBlock = calculateValidUntilBlockCache{ validatorsCount: validatorsCount, expiresAt: blockCount + cacheTimeout, } + c.cacheLock.Unlock() } return blockCount + validatorsCount + 1, nil } @@ -990,18 +1005,33 @@ func (c *Client) AddNetworkFee(tx *transaction.Transaction, extraFee int64, accs } // GetNetwork returns the network magic of the RPC node client connected to. -func (c *Client) GetNetwork() netmode.Magic { - return c.network +func (c *Client) GetNetwork() (netmode.Magic, error) { + c.cacheLock.RLock() + defer c.cacheLock.RUnlock() + + if !c.cache.initDone { + return 0, errNetworkNotInitialized + } + return c.cache.network, nil } // StateRootInHeader returns true if state root is contained in block header. -func (c *Client) StateRootInHeader() bool { - return c.stateRootInHeader +// You should initialize Client cache with Init() before calling StateRootInHeader. +func (c *Client) StateRootInHeader() (bool, error) { + c.cacheLock.RLock() + defer c.cacheLock.RUnlock() + + if !c.cache.initDone { + return false, errNetworkNotInitialized + } + return c.cache.stateRootInHeader, nil } // GetNativeContractHash returns native contract hash by its name. func (c *Client) GetNativeContractHash(name string) (util.Uint160, error) { + c.cacheLock.RLock() hash, ok := c.cache.nativeHashes[name] + c.cacheLock.RUnlock() if ok { return hash, nil } @@ -1009,6 +1039,8 @@ func (c *Client) GetNativeContractHash(name string) (util.Uint160, error) { if err != nil { return util.Uint160{}, err } + c.cacheLock.Lock() c.cache.nativeHashes[name] = cs.Hash + c.cacheLock.Unlock() return cs.Hash, nil } diff --git a/pkg/rpc/client/rpc_test.go b/pkg/rpc/client/rpc_test.go index 51c008785..1e58d2857 100644 --- a/pkg/rpc/client/rpc_test.go +++ b/pkg/rpc/client/rpc_test.go @@ -6,6 +6,7 @@ import ( "encoding/base64" "encoding/hex" "encoding/json" + "errors" "fmt" "math/big" "net/http" @@ -1704,6 +1705,7 @@ func TestRPCClients(t *testing.T) { testRPCClient(t, func(ctx context.Context, endpoint string, opts Options) (*Client, error) { c, err := New(ctx, endpoint, opts) require.NoError(t, err) + c.getNextRequestID = getTestRequestID require.NoError(t, c.Init()) return c, nil }) @@ -1712,6 +1714,7 @@ func TestRPCClients(t *testing.T) { testRPCClient(t, func(ctx context.Context, endpoint string, opts Options) (*Client, error) { wsc, err := NewWS(ctx, httpURLtoWS(endpoint), opts) require.NoError(t, err) + wsc.getNextRequestID = getTestRequestID require.NoError(t, wsc.Init()) return &wsc.Client, nil }) @@ -1731,6 +1734,7 @@ func testRPCClient(t *testing.T, newClient func(context.Context, string, Options if err != nil { t.Fatal(err) } + c.getNextRequestID = getTestRequestID actual, err := testCase.invoke(c) if testCase.fails { @@ -1754,14 +1758,14 @@ func testRPCClient(t *testing.T, newClient func(context.Context, string, Options endpoint := srv.URL opts := Options{} - c, err := newClient(context.TODO(), endpoint, opts) - if err != nil { - t.Fatal(err) - } - for _, testCase := range testBatch { t.Run(testCase.name, func(t *testing.T) { - _, err := testCase.invoke(c) + c, err := newClient(context.TODO(), endpoint, opts) + if err != nil { + t.Fatal(err) + } + c.getNextRequestID = getTestRequestID + _, err = testCase.invoke(c) assert.Error(t, err) }) } @@ -1877,6 +1881,7 @@ func TestCalculateValidUntilBlock(t *testing.T) { if err != nil { t.Fatal(err) } + c.getNextRequestID = getTestRequestID require.NoError(t, c.Init()) validUntilBlock, err := c.CalculateValidUntilBlock() @@ -1912,9 +1917,11 @@ func TestGetNetwork(t *testing.T) { if err != nil { t.Fatal(err) } + c.getNextRequestID = getTestRequestID // network was not initialised - require.Equal(t, netmode.Magic(0), c.GetNetwork()) - require.Equal(t, false, c.initDone) + _, err = c.GetNetwork() + require.True(t, errors.Is(err, errNetworkNotInitialized)) + require.Equal(t, false, c.cache.initDone) }) t.Run("good", func(t *testing.T) { @@ -1922,8 +1929,11 @@ func TestGetNetwork(t *testing.T) { if err != nil { t.Fatal(err) } + c.getNextRequestID = getTestRequestID require.NoError(t, c.Init()) - require.Equal(t, netmode.UnitTestNet, c.GetNetwork()) + m, err := c.GetNetwork() + require.NoError(t, err) + require.Equal(t, netmode.UnitTestNet, m) }) } @@ -1941,6 +1951,7 @@ func TestUninitedClient(t *testing.T) { c, err := New(context.TODO(), endpoint, opts) require.NoError(t, err) + c.getNextRequestID = getTestRequestID _, err = c.GetBlockByIndex(0) require.Error(t, err) @@ -1966,3 +1977,7 @@ func newTestNEF(script []byte) nef.File { ne.Checksum = ne.CalculateChecksum() return ne } + +func getTestRequestID() uint64 { + return 1 +} diff --git a/pkg/rpc/client/wsclient.go b/pkg/rpc/client/wsclient.go index f2ed6aed8..89bd4b934 100644 --- a/pkg/rpc/client/wsclient.go +++ b/pkg/rpc/client/wsclient.go @@ -4,6 +4,8 @@ import ( "context" "encoding/json" "errors" + "strconv" + "sync" "time" "github.com/gorilla/websocket" @@ -20,6 +22,8 @@ import ( // servers. It's supposed to be faster than Client because it has persistent // connection to the server and at the same time is exposes some functionality // that is only provided via websockets (like event subscription mechanism). +// WSClient is thread-safe and can be used from multiple goroutines to perform +// RPC requests. type WSClient struct { Client // Notifications is a channel that is used to send events received from @@ -30,12 +34,16 @@ type WSClient struct { // be closed, so make sure to handle this. Notifications chan Notification - ws *websocket.Conn - done chan struct{} - responses chan *response.Raw - requests chan *request.Raw - shutdown chan struct{} - subscriptions map[string]bool + ws *websocket.Conn + done chan struct{} + requests chan *request.Raw + shutdown chan struct{} + + subscriptionsLock sync.RWMutex + subscriptions map[string]bool + + respLock sync.RWMutex + respChannels map[uint64]chan *response.Raw } // Notification represents server-generated notification for client subscriptions. @@ -73,29 +81,29 @@ const ( // You should call Init method to initialize network magic the client is // operating on. func NewWS(ctx context.Context, endpoint string, opts Options) (*WSClient, error) { - cl, err := New(ctx, endpoint, opts) - if err != nil { - return nil, err - } - - cl.cli = nil - dialer := websocket.Dialer{HandshakeTimeout: opts.DialTimeout} ws, _, err := dialer.Dial(endpoint, nil) if err != nil { return nil, err } wsc := &WSClient{ - Client: *cl, + Client: Client{}, Notifications: make(chan Notification), ws: ws, shutdown: make(chan struct{}), done: make(chan struct{}), - responses: make(chan *response.Raw), + respChannels: make(map[uint64]chan *response.Raw), requests: make(chan *request.Raw), subscriptions: make(map[string]bool), } + + err = initClient(ctx, &wsc.Client, endpoint, opts) + if err != nil { + return nil, err + } + wsc.Client.cli = nil + go wsc.wsReader() go wsc.wsWriter() wsc.requestF = wsc.makeWsRequest @@ -141,7 +149,12 @@ readloop: var val interface{} switch event { case response.BlockEventID: - val = block.New(c.StateRootInHeader()) + sr, err := c.StateRootInHeader() + if err != nil { + // Client is not initialised. + break + } + val = block.New(sr) case response.TransactionEventID: val = &transaction.Transaction{} case response.NotificationEventID: @@ -170,14 +183,27 @@ readloop: resp.JSONRPC = rr.JSONRPC resp.Error = rr.Error resp.Result = rr.Result - c.responses <- resp + id, err := strconv.Atoi(string(resp.ID)) + if err != nil { + break // Malformed response (invalid response ID). + } + ch := c.getResponseChannel(uint64(id)) + if ch == nil { + break // Unknown response (unexpected response ID). + } + ch <- resp } else { // Malformed response, neither valid request, nor valid response. break } } close(c.done) - close(c.responses) + c.respLock.Lock() + for _, ch := range c.respChannels { + close(ch) + } + c.respChannels = nil + c.respLock.Unlock() close(c.Notifications) } @@ -212,16 +238,41 @@ func (c *WSClient) wsWriter() { } } +func (c *WSClient) registerRespChannel(id uint64, ch chan *response.Raw) { + c.respLock.Lock() + defer c.respLock.Unlock() + c.respChannels[id] = ch +} + +func (c *WSClient) unregisterRespChannel(id uint64) { + c.respLock.Lock() + defer c.respLock.Unlock() + if ch, ok := c.respChannels[id]; ok { + delete(c.respChannels, id) + close(ch) + } +} + +func (c *WSClient) getResponseChannel(id uint64) chan *response.Raw { + c.respLock.RLock() + defer c.respLock.RUnlock() + return c.respChannels[id] +} + func (c *WSClient) makeWsRequest(r *request.Raw) (*response.Raw, error) { + ch := make(chan *response.Raw) + c.registerRespChannel(r.ID, ch) + select { case <-c.done: - return nil, errors.New("connection lost") + return nil, errors.New("connection lost before sending the request") case c.requests <- r: } select { case <-c.done: - return nil, errors.New("connection lost") - case resp := <-c.responses: + return nil, errors.New("connection lost while waiting for the response") + case resp := <-ch: + c.unregisterRespChannel(r.ID) return resp, nil } } @@ -232,6 +283,10 @@ func (c *WSClient) performSubscription(params request.RawParams) (string, error) if err := c.performRequest("subscribe", params, &resp); err != nil { return "", err } + + c.subscriptionsLock.Lock() + defer c.subscriptionsLock.Unlock() + c.subscriptions[resp] = true return resp, nil } @@ -239,6 +294,9 @@ func (c *WSClient) performSubscription(params request.RawParams) (string, error) func (c *WSClient) performUnsubscription(id string) error { var resp bool + c.subscriptionsLock.Lock() + defer c.subscriptionsLock.Unlock() + if !c.subscriptions[id] { return errors.New("no subscription with this ID") } @@ -320,11 +378,18 @@ func (c *WSClient) Unsubscribe(id string) error { // UnsubscribeAll removes all active subscriptions of current client. func (c *WSClient) UnsubscribeAll() error { + c.subscriptionsLock.Lock() + defer c.subscriptionsLock.Unlock() + for id := range c.subscriptions { - err := c.performUnsubscription(id) - if err != nil { + var resp bool + if err := c.performRequest("unsubscribe", request.NewRawParams(id), &resp); err != nil { return err } + if !resp { + return errors.New("unsubscribe method returned false result") + } + delete(c.subscriptions, id) } return nil } diff --git a/pkg/rpc/client/wsclient_test.go b/pkg/rpc/client/wsclient_test.go index 2f43e77b6..0945aff77 100644 --- a/pkg/rpc/client/wsclient_test.go +++ b/pkg/rpc/client/wsclient_test.go @@ -6,7 +6,10 @@ import ( "fmt" "net/http" "net/http/httptest" + "sort" + "strconv" "strings" + "sync" "testing" "time" @@ -15,6 +18,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/rpc/request" "github.com/nspcc-dev/neo-go/pkg/util" "github.com/stretchr/testify/require" + "go.uber.org/atomic" ) func TestWSClientClose(t *testing.T) { @@ -45,6 +49,7 @@ func TestWSClientSubscription(t *testing.T) { srv := initTestServer(t, `{"jsonrpc": "2.0", "id": 1, "result": "55aaff00"}`) wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) require.NoError(t, err) + wsc.getNextRequestID = getTestRequestID require.NoError(t, wsc.Init()) id, err := f(wsc) require.NoError(t, err) @@ -58,6 +63,7 @@ func TestWSClientSubscription(t *testing.T) { srv := initTestServer(t, `{"jsonrpc": "2.0", "id": 1, "error":{"code":-32602,"message":"Invalid Params"}}`) wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) require.NoError(t, err) + wsc.getNextRequestID = getTestRequestID require.NoError(t, wsc.Init()) _, err = f(wsc) require.Error(t, err) @@ -107,6 +113,7 @@ func TestWSClientUnsubscription(t *testing.T) { srv := initTestServer(t, rc.response) wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) require.NoError(t, err) + wsc.getNextRequestID = getTestRequestID require.NoError(t, wsc.Init()) rc.code(t, wsc) }) @@ -143,7 +150,9 @@ func TestWSClientEvents(t *testing.T) { wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) require.NoError(t, err) - wsc.network = netmode.UnitTestNet + wsc.getNextRequestID = getTestRequestID + wsc.cache.initDone = true // Our server mock is restricted, so perform initialisation manually. + wsc.cache.network = netmode.UnitTestNet for range events { select { case _, ok = <-wsc.Notifications: @@ -166,6 +175,7 @@ func TestWSExecutionVMStateCheck(t *testing.T) { srv := initTestServer(t, `{"jsonrpc": "2.0", "id": 1, "result": "55aaff00"}`) wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) require.NoError(t, err) + wsc.getNextRequestID = getTestRequestID require.NoError(t, wsc.Init()) filter := "NONE" _, err = wsc.SubscribeForTransactionExecutions(&filter) @@ -315,7 +325,8 @@ func TestWSFilteredSubscriptions(t *testing.T) { })) wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) require.NoError(t, err) - wsc.network = netmode.UnitTestNet + wsc.getNextRequestID = getTestRequestID + wsc.cache.network = netmode.UnitTestNet c.clientCode(t, wsc) wsc.Close() }) @@ -328,6 +339,8 @@ func TestNewWS(t *testing.T) { t.Run("good", func(t *testing.T) { c, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) require.NoError(t, err) + c.getNextRequestID = getTestRequestID + c.cache.network = netmode.UnitTestNet require.NoError(t, c.Init()) }) t.Run("bad URL", func(t *testing.T) { @@ -335,3 +348,96 @@ func TestNewWS(t *testing.T) { require.Error(t, err) }) } + +func TestWSConcurrentAccess(t *testing.T) { + var ids struct { + lock sync.RWMutex + m map[int]struct{} + } + ids.m = make(map[int]struct{}) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + if req.URL.Path == "/ws" && req.Method == "GET" { + var upgrader = websocket.Upgrader{} + ws, err := upgrader.Upgrade(w, req, nil) + require.NoError(t, err) + for { + err = ws.SetReadDeadline(time.Now().Add(2 * time.Second)) + require.NoError(t, err) + _, p, err := ws.ReadMessage() + if err != nil { + break + } + r := request.NewIn() + err = json.Unmarshal(p, r) + if err != nil { + t.Fatalf("Cannot decode request body: %s", req.Body) + } + i, err := strconv.Atoi(string(r.RawID)) + require.NoError(t, err) + ids.lock.Lock() + ids.m[i] = struct{}{} + ids.lock.Unlock() + var response string + // Different responses to catch possible unmarshalling errors connected with invalid IDs distribution. + switch r.Method { + case "getblockcount": + response = fmt.Sprintf(`{"id":%s,"jsonrpc":"2.0","result":123}`, r.RawID) + case "getversion": + response = fmt.Sprintf(`{"id":%s,"jsonrpc":"2.0","result":{"network":42,"tcpport":20332,"wsport":20342,"nonce":2153672787,"useragent":"/NEO-GO:0.73.1-pre-273-ge381358/"}}`, r.RawID) + case "getblockhash": + response = fmt.Sprintf(`{"id":%s,"jsonrpc":"2.0","result":"0x157ca5e5b8cf8f84c9660502a3270b346011612bded1514a6847f877c433a9bb"}`, r.RawID) + } + err = ws.SetWriteDeadline(time.Now().Add(2 * time.Second)) + require.NoError(t, err) + err = ws.WriteMessage(1, []byte(response)) + if err != nil { + break + } + } + ws.Close() + return + } + })) + t.Cleanup(srv.Close) + + wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) + require.NoError(t, err) + batchCount := 100 + completed := atomic.NewInt32(0) + for i := 0; i < batchCount; i++ { + go func() { + _, err := wsc.GetBlockCount() + require.NoError(t, err) + completed.Inc() + }() + go func() { + _, err := wsc.GetBlockHash(123) + require.NoError(t, err) + completed.Inc() + }() + + go func() { + _, err := wsc.GetVersion() + require.NoError(t, err) + completed.Inc() + }() + } + require.Eventually(t, func() bool { + return int(completed.Load()) == batchCount*3 + }, time.Second, 100*time.Millisecond) + + ids.lock.RLock() + require.True(t, len(ids.m) > batchCount) + idsList := make([]int, 0, len(ids.m)) + for i := range ids.m { + idsList = append(idsList, i) + } + ids.lock.RUnlock() + + sort.Ints(idsList) + require.Equal(t, 1, idsList[0]) + require.Less(t, idsList[len(idsList)-1], + batchCount*3+1) // batchCount*requestsPerBatch+1 + wsc.Close() +} diff --git a/pkg/rpc/request/types.go b/pkg/rpc/request/types.go index 654771060..2281ef3ae 100644 --- a/pkg/rpc/request/types.go +++ b/pkg/rpc/request/types.go @@ -32,12 +32,12 @@ func NewRawParams(vals ...interface{}) RawParams { return p } -// Raw represents JSON-RPC request. +// Raw represents JSON-RPC request on the Client side. type Raw struct { JSONRPC string `json:"jsonrpc"` Method string `json:"method"` RawParams []interface{} `json:"params"` - ID int `json:"id"` + ID uint64 `json:"id"` } // Request contains standard JSON-RPC 2.0 request and batch of diff --git a/pkg/rpc/server/client_test.go b/pkg/rpc/server/client_test.go index 18c2b418d..4e9c8c4d3 100644 --- a/pkg/rpc/server/client_test.go +++ b/pkg/rpc/server/client_test.go @@ -555,9 +555,11 @@ func TestSignAndPushP2PNotaryRequest(t *testing.T) { c, err := client.New(context.Background(), httpSrv.URL, client.Options{}) require.NoError(t, err) + acc, err := wallet.NewAccount() + require.NoError(t, err) t.Run("client wasn't initialized", func(t *testing.T) { - _, err := c.SignAndPushP2PNotaryRequest(nil, nil, 0, 0, 0, nil) + _, err := c.SignAndPushP2PNotaryRequest(transaction.New([]byte{byte(opcode.RET)}, 123), []byte{byte(opcode.RET)}, -1, 0, 100, acc) require.NotNil(t, err) }) @@ -567,8 +569,6 @@ func TestSignAndPushP2PNotaryRequest(t *testing.T) { require.NotNil(t, err) }) - acc, err := wallet.NewAccount() - require.NoError(t, err) t.Run("bad fallback script", func(t *testing.T) { _, err := c.SignAndPushP2PNotaryRequest(nil, []byte{byte(opcode.ASSERT)}, -1, 0, 0, acc) require.NotNil(t, err) @@ -649,7 +649,7 @@ func TestCalculateNotaryFee(t *testing.T) { t.Run("client not initialized", func(t *testing.T) { _, err := c.CalculateNotaryFee(0) - require.NotNil(t, err) + require.NoError(t, err) // Do not require client initialisation for this. }) }