Merge pull request #2367 from nspcc-dev/rpc/thread-safe

rpc: take care of RPC clients
This commit is contained in:
Roman Khimov 2022-02-24 20:11:15 +03:00 committed by GitHub
commit 870fd024c9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 369 additions and 129 deletions

View file

@ -682,7 +682,11 @@ func invokeWithArgs(ctx *cli.Context, acc *wallet.Account, wall *wallet.Wallet,
if err != nil {
return sender, cli.NewExitError(fmt.Errorf("failed to create tx: %w", err), 1)
}
if err := paramcontext.InitAndSave(c.GetNetwork(), tx, acc, out); err != nil {
m, err := c.GetNetwork()
if err != nil {
return sender, cli.NewExitError(fmt.Errorf("failed to save tx: %w", err), 1)
}
if err := paramcontext.InitAndSave(m, tx, acc, out); err != nil {
return sender, cli.NewExitError(err, 1)
}
fmt.Fprintln(ctx.App.Writer, tx.Hash().StringLE())

View file

@ -274,7 +274,11 @@ func signAndSendNEP11Transfer(ctx *cli.Context, c *client.Client, acc *wallet.Ac
tx.SystemFee += int64(sysgas)
if outFile := ctx.String("out"); outFile != "" {
if err := paramcontext.InitAndSave(c.GetNetwork(), tx, acc, outFile); err != nil {
m, err := c.GetNetwork()
if err != nil {
return cli.NewExitError(fmt.Errorf("failed to save tx: %w", err), 1)
}
if err := paramcontext.InitAndSave(m, tx, acc, outFile); err != nil {
return cli.NewExitError(err, 1)
}
} else {

View file

@ -647,7 +647,11 @@ func signAndSendNEP17Transfer(ctx *cli.Context, c *client.Client, acc *wallet.Ac
tx.SystemFee += int64(sysgas)
if outFile := ctx.String("out"); outFile != "" {
if err := paramcontext.InitAndSave(c.GetNetwork(), tx, acc, outFile); err != nil {
m, err := c.GetNetwork()
if err != nil {
return cli.NewExitError(fmt.Errorf("failed to save tx: %w", err), 1)
}
if err := paramcontext.InitAndSave(m, tx, acc, outFile); err != nil {
return cli.NewExitError(err, 1)
}
} else {

View file

@ -9,6 +9,7 @@ import (
"net"
"net/http"
"net/url"
"sync"
"time"
"github.com/nspcc-dev/neo-go/pkg/config/netmode"
@ -16,6 +17,7 @@ import (
"github.com/nspcc-dev/neo-go/pkg/rpc/request"
"github.com/nspcc-dev/neo-go/pkg/rpc/response"
"github.com/nspcc-dev/neo-go/pkg/util"
"go.uber.org/atomic"
)
const (
@ -26,17 +28,26 @@ const (
)
// Client represents the middleman for executing JSON RPC calls
// to remote NEO RPC nodes.
// to remote NEO RPC nodes. Client is thread-safe and can be used from
// multiple goroutines.
type Client struct {
cli *http.Client
endpoint *url.URL
network netmode.Magic
stateRootInHeader bool
initDone bool
ctx context.Context
opts Options
requestF func(*request.Raw) (*response.Raw, error)
cacheLock sync.RWMutex
// cache stores RPC node related information client is bound to.
// cache is mostly filled in during Init(), but can also be updated
// during regular Client lifecycle.
cache cache
latestReqID *atomic.Uint64
// getNextRequestID returns ID to be used for subsequent request creation.
// It is defined on Client so that our testing code can override this method
// for the sake of more predictable request IDs generation behaviour.
getNextRequestID func() uint64
}
// Options defines options for the RPC client.
@ -56,6 +67,9 @@ type Options struct {
// cache stores cache values for the RPC client methods.
type cache struct {
initDone bool
network netmode.Magic
stateRootInHeader bool
calculateValidUntilBlock calculateValidUntilBlockCache
nativeHashes map[string]util.Uint160
}
@ -70,10 +84,19 @@ type calculateValidUntilBlockCache struct {
// New returns a new Client ready to use. You should call Init method to
// initialize network magic the client is operating on.
func New(ctx context.Context, endpoint string, opts Options) (*Client, error) {
url, err := url.Parse(endpoint)
cl := new(Client)
err := initClient(ctx, cl, endpoint, opts)
if err != nil {
return nil, err
}
return cl, nil
}
func initClient(ctx context.Context, cl *Client, endpoint string, opts Options) error {
url, err := url.Parse(endpoint)
if err != nil {
return err
}
if opts.DialTimeout <= 0 {
opts.DialTimeout = defaultDialTimeout
@ -97,33 +120,41 @@ func New(ctx context.Context, endpoint string, opts Options) (*Client, error) {
// if opts.Cert != "" && opts.Key != "" {
// }
cl := &Client{
ctx: ctx,
cli: httpClient,
endpoint: url,
cache: cache{
cl.ctx = ctx
cl.cli = httpClient
cl.endpoint = url
cl.cache = cache{
nativeHashes: make(map[string]util.Uint160),
},
}
cl.latestReqID = atomic.NewUint64(0)
cl.getNextRequestID = (cl).getRequestID
cl.opts = opts
cl.requestF = cl.makeHTTPRequest
return cl, nil
return nil
}
func (c *Client) getRequestID() uint64 {
return c.latestReqID.Inc()
}
// Init sets magic of the network client connected to, stateRootInHeader option
// and native NEO, GAS and Policy contracts scripthashes. This method should be
// called before any transaction-, header- or block-related requests in order to
// deserialize responses properly.
// called before any header- or block-related requests in order to deserialize
// responses properly.
func (c *Client) Init() error {
version, err := c.GetVersion()
if err != nil {
return fmt.Errorf("failed to get network magic: %w", err)
}
c.network = version.Protocol.Network
c.stateRootInHeader = version.Protocol.StateRootInHeader
c.cacheLock.Lock()
defer c.cacheLock.Unlock()
c.cache.network = version.Protocol.Network
c.cache.stateRootInHeader = version.Protocol.StateRootInHeader
if version.Protocol.MillisecondsPerBlock == 0 {
c.network = version.Magic
c.stateRootInHeader = version.StateRootInHeader
c.cache.network = version.Magic
c.cache.stateRootInHeader = version.StateRootInHeader
}
neoContractHash, err := c.GetContractStateByAddressOrName(nativenames.Neo)
if err != nil {
@ -140,7 +171,7 @@ func (c *Client) Init() error {
return fmt.Errorf("failed to get Policy contract scripthash: %w", err)
}
c.cache.nativeHashes[nativenames.Policy] = policyContractHash.Hash
c.initDone = true
c.cache.initDone = true
return nil
}
@ -149,7 +180,7 @@ func (c *Client) performRequest(method string, p request.RawParams, v interface{
JSONRPC: request.JSONRPCVersion,
Method: method,
RawParams: p.Values,
ID: 1,
ID: c.getNextRequestID(),
}
raw, err := c.requestF(&r)

View file

@ -46,9 +46,6 @@ func (c *Client) NEP11TokenInfo(tokenHash util.Uint160) (*wallet.Token, error) {
// given account and sends it to the network returning just a hash of it.
func (c *Client) TransferNEP11(acc *wallet.Account, to util.Uint160,
tokenHash util.Uint160, tokenID string, data interface{}, gas int64, cosigners []SignerAccount) (util.Uint256, error) {
if !c.initDone {
return util.Uint256{}, errNetworkNotInitialized
}
tx, err := c.CreateNEP11TransferTx(acc, tokenHash, gas, cosigners, to, tokenID, data)
if err != nil {
return util.Uint256{}, err
@ -144,9 +141,6 @@ func (c *Client) NEP11NDOwnerOf(tokenHash util.Uint160, tokenID []byte) (util.Ui
// sends it to the network returning just a hash of it.
func (c *Client) TransferNEP11D(acc *wallet.Account, to util.Uint160,
tokenHash util.Uint160, amount int64, tokenID []byte, data interface{}, gas int64, cosigners []SignerAccount) (util.Uint256, error) {
if !c.initDone {
return util.Uint256{}, errNetworkNotInitialized
}
from, err := address.StringToUint160(acc.Address)
if err != nil {
return util.Uint256{}, fmt.Errorf("bad account address: %w", err)

View file

@ -142,10 +142,6 @@ func (c *Client) CreateTxFromScript(script []byte, acc *wallet.Account, sysFee,
// impossible (e.g. due to locked cosigner's account) an error is returned.
func (c *Client) TransferNEP17(acc *wallet.Account, to util.Uint160, token util.Uint160,
amount int64, gas int64, data interface{}, cosigners []SignerAccount) (util.Uint256, error) {
if !c.initDone {
return util.Uint256{}, errNetworkNotInitialized
}
tx, err := c.CreateNEP17TransferTx(acc, to, token, amount, gas, data, cosigners)
if err != nil {
return util.Uint256{}, err
@ -156,10 +152,6 @@ func (c *Client) TransferNEP17(acc *wallet.Account, to util.Uint160, token util.
// MultiTransferNEP17 is similar to TransferNEP17, buf allows to have multiple recipients.
func (c *Client) MultiTransferNEP17(acc *wallet.Account, gas int64, recipients []TransferTarget, cosigners []SignerAccount) (util.Uint256, error) {
if !c.initDone {
return util.Uint256{}, errNetworkNotInitialized
}
tx, err := c.CreateNEP17MultiTransferTx(acc, gas, recipients, cosigners)
if err != nil {
return util.Uint256{}, err

View file

@ -10,25 +10,16 @@ import (
// GetFeePerByte invokes `getFeePerByte` method on a native Policy contract.
func (c *Client) GetFeePerByte() (int64, error) {
if !c.initDone {
return 0, errNetworkNotInitialized
}
return c.invokeNativePolicyMethod("getFeePerByte")
}
// GetExecFeeFactor invokes `getExecFeeFactor` method on a native Policy contract.
func (c *Client) GetExecFeeFactor() (int64, error) {
if !c.initDone {
return 0, errNetworkNotInitialized
}
return c.invokeNativePolicyMethod("getExecFeeFactor")
}
// GetStoragePrice invokes `getStoragePrice` method on a native Policy contract.
func (c *Client) GetStoragePrice() (int64, error) {
if !c.initDone {
return 0, errNetworkNotInitialized
}
return c.invokeNativePolicyMethod("getStoragePrice")
}
@ -43,10 +34,11 @@ func (c *Client) GetMaxNotValidBeforeDelta() (int64, error) {
// invokeNativePolicy method invokes Get* method on a native Policy contract.
func (c *Client) invokeNativePolicyMethod(operation string) (int64, error) {
if !c.initDone {
return 0, errNetworkNotInitialized
policyHash, err := c.GetNativeContractHash(nativenames.Policy)
if err != nil {
return 0, fmt.Errorf("failed to get native Policy hash: %w", err)
}
return c.invokeNativeGetMethod(c.cache.nativeHashes[nativenames.Policy], operation)
return c.invokeNativeGetMethod(policyHash, operation)
}
func (c *Client) invokeNativeGetMethod(hash util.Uint160, operation string) (int64, error) {
@ -63,10 +55,11 @@ func (c *Client) invokeNativeGetMethod(hash util.Uint160, operation string) (int
// IsBlocked invokes `isBlocked` method on native Policy contract.
func (c *Client) IsBlocked(hash util.Uint160) (bool, error) {
if !c.initDone {
return false, errNetworkNotInitialized
policyHash, err := c.GetNativeContractHash(nativenames.Policy)
if err != nil {
return false, fmt.Errorf("failed to get native Policy hash: %w", err)
}
result, err := c.InvokeFunction(c.cache.nativeHashes[nativenames.Policy], "isBlocked", []smartcontract.Parameter{{
result, err := c.InvokeFunction(policyHash, "isBlocked", []smartcontract.Parameter{{
Type: smartcontract.Hash160Type,
Value: hash,
}}, nil)

View file

@ -94,14 +94,15 @@ func (c *Client) getBlock(params request.RawParams) (*block.Block, error) {
err error
b *block.Block
)
if !c.initDone {
return nil, errNetworkNotInitialized
}
if err = c.performRequest("getblock", params, &resp); err != nil {
return nil, err
}
r := io.NewBinReaderFromBuf(resp)
b = block.New(c.StateRootInHeader())
sr, err := c.StateRootInHeader()
if err != nil {
return nil, err
}
b = block.New(sr)
b.DecodeBinary(r)
if r.Err != nil {
return nil, r.Err
@ -127,9 +128,11 @@ func (c *Client) getBlockVerbose(params request.RawParams) (*result.Block, error
resp = &result.Block{}
err error
)
if !c.initDone {
return nil, errNetworkNotInitialized
sr, err := c.StateRootInHeader()
if err != nil {
return nil, err
}
resp.Header.StateRootEnabled = sr
if err = c.performRequest("getblock", params, resp); err != nil {
return nil, err
}
@ -157,14 +160,16 @@ func (c *Client) GetBlockHeader(hash util.Uint256) (*block.Header, error) {
resp []byte
h *block.Header
)
if !c.initDone {
return nil, errNetworkNotInitialized
}
if err := c.performRequest("getblockheader", params, &resp); err != nil {
return nil, err
}
sr, err := c.StateRootInHeader()
if err != nil {
return nil, err
}
r := io.NewBinReaderFromBuf(resp)
h = new(block.Header)
h.StateRootEnabled = sr
h.DecodeBinary(r)
if r.Err != nil {
return nil, r.Err
@ -266,6 +271,14 @@ func (c *Client) GetNativeContracts() ([]state.NativeContract, error) {
if err := c.performRequest("getnativecontracts", params, &resp); err != nil {
return resp, err
}
// Update native contract hashes.
c.cacheLock.Lock()
for _, cs := range resp {
c.cache.nativeHashes[cs.Manifest.Name] = cs.Hash
}
c.cacheLock.Unlock()
return resp, nil
}
@ -399,17 +412,13 @@ func (c *Client) GetRawMemPool() ([]util.Uint256, error) {
return *resp, nil
}
// GetRawTransaction returns a transaction by hash. You should initialize network magic
// with Init before calling GetRawTransaction.
// GetRawTransaction returns a transaction by hash.
func (c *Client) GetRawTransaction(hash util.Uint256) (*transaction.Transaction, error) {
var (
params = request.NewRawParams(hash.StringLE())
resp []byte
err error
)
if !c.initDone {
return nil, errNetworkNotInitialized
}
if err = c.performRequest("getrawtransaction", params, &resp); err != nil {
return nil, err
}
@ -421,8 +430,7 @@ func (c *Client) GetRawTransaction(hash util.Uint256) (*transaction.Transaction,
}
// GetRawTransactionVerbose returns a transaction wrapper with additional
// metadata by transaction's hash. You should initialize network magic
// with Init before calling GetRawTransactionVerbose.
// metadata by transaction's hash.
// NOTE: to get transaction.ID and transaction.Size, use t.Hash() and io.GetVarSize(t) respectively.
func (c *Client) GetRawTransactionVerbose(hash util.Uint256) (*result.TransactionOutputRaw, error) {
var (
@ -430,9 +438,6 @@ func (c *Client) GetRawTransactionVerbose(hash util.Uint256) (*result.Transactio
resp = &result.TransactionOutputRaw{}
err error
)
if !c.initDone {
return nil, errNetworkNotInitialized
}
if err = c.performRequest("getrawtransaction", params, resp); err != nil {
return nil, err
}
@ -687,7 +692,11 @@ func (c *Client) SignAndPushTx(tx *transaction.Transaction, acc *wallet.Account,
txHash util.Uint256
err error
)
if err = acc.SignTx(c.GetNetwork(), tx); err != nil {
m, err := c.GetNetwork()
if err != nil {
return txHash, fmt.Errorf("failed to sign tx: %w", err)
}
if err = acc.SignTx(m, tx); err != nil {
return txHash, fmt.Errorf("failed to sign tx: %w", err)
}
// try to add witnesses for the rest of the signers
@ -695,7 +704,7 @@ func (c *Client) SignAndPushTx(tx *transaction.Transaction, acc *wallet.Account,
var isOk bool
for _, cosigner := range cosigners {
if signer.Account == cosigner.Signer.Account {
err = cosigner.Account.SignTx(c.GetNetwork(), tx)
err = cosigner.Account.SignTx(m, tx)
if err != nil { // then account is non-contract-based and locked, but let's provide more detailed error
if paramNum := len(cosigner.Account.Contract.Parameters); paramNum != 0 && cosigner.Account.Contract.Deployed {
return txHash, fmt.Errorf("failed to add contract-based witness for signer #%d (%s): "+
@ -771,9 +780,6 @@ func getSigners(sender *wallet.Account, cosigners []SignerAccount) ([]transactio
// Note: client should be initialized before SignAndPushP2PNotaryRequest call.
func (c *Client) SignAndPushP2PNotaryRequest(mainTx *transaction.Transaction, fallbackScript []byte, fallbackSysFee int64, fallbackNetFee int64, fallbackValidFor uint32, acc *wallet.Account) (*payload.P2PNotaryRequest, error) {
var err error
if !c.initDone {
return nil, errNetworkNotInitialized
}
notaryHash, err := c.GetNativeContractHash(nativenames.Notary)
if err != nil {
return nil, fmt.Errorf("failed to get native Notary hash: %w", err)
@ -835,7 +841,11 @@ func (c *Client) SignAndPushP2PNotaryRequest(mainTx *transaction.Transaction, fa
VerificationScript: []byte{},
},
}
if err = acc.SignTx(c.GetNetwork(), fallbackTx); err != nil {
m, err := c.GetNetwork()
if err != nil {
return nil, fmt.Errorf("failed to sign fallback tx: %w", err)
}
if err = acc.SignTx(m, fallbackTx); err != nil {
return nil, fmt.Errorf("failed to sign fallback tx: %w", err)
}
fallbackHash := fallbackTx.Hash()
@ -844,7 +854,7 @@ func (c *Client) SignAndPushP2PNotaryRequest(mainTx *transaction.Transaction, fa
FallbackTransaction: fallbackTx,
}
req.Witness = transaction.Witness{
InvocationScript: append([]byte{byte(opcode.PUSHDATA1), 64}, acc.PrivateKey().SignHashable(uint32(c.GetNetwork()), req)...),
InvocationScript: append([]byte{byte(opcode.PUSHDATA1), 64}, acc.PrivateKey().SignHashable(uint32(m), req)...),
VerificationScript: acc.GetVerificationScript(),
}
actualHash, err := c.SubmitP2PNotaryRequest(req)
@ -922,18 +932,23 @@ func (c *Client) CalculateValidUntilBlock() (uint32, error) {
return result, fmt.Errorf("can't get block count: %w", err)
}
c.cacheLock.RLock()
if c.cache.calculateValidUntilBlock.expiresAt > blockCount {
validatorsCount = c.cache.calculateValidUntilBlock.validatorsCount
c.cacheLock.RUnlock()
} else {
c.cacheLock.RUnlock()
validators, err := c.GetNextBlockValidators()
if err != nil {
return result, fmt.Errorf("can't get validators: %w", err)
}
validatorsCount = uint32(len(validators))
c.cacheLock.Lock()
c.cache.calculateValidUntilBlock = calculateValidUntilBlockCache{
validatorsCount: validatorsCount,
expiresAt: blockCount + cacheTimeout,
}
c.cacheLock.Unlock()
}
return blockCount + validatorsCount + 1, nil
}
@ -990,18 +1005,33 @@ func (c *Client) AddNetworkFee(tx *transaction.Transaction, extraFee int64, accs
}
// GetNetwork returns the network magic of the RPC node client connected to.
func (c *Client) GetNetwork() netmode.Magic {
return c.network
func (c *Client) GetNetwork() (netmode.Magic, error) {
c.cacheLock.RLock()
defer c.cacheLock.RUnlock()
if !c.cache.initDone {
return 0, errNetworkNotInitialized
}
return c.cache.network, nil
}
// StateRootInHeader returns true if state root is contained in block header.
func (c *Client) StateRootInHeader() bool {
return c.stateRootInHeader
// You should initialize Client cache with Init() before calling StateRootInHeader.
func (c *Client) StateRootInHeader() (bool, error) {
c.cacheLock.RLock()
defer c.cacheLock.RUnlock()
if !c.cache.initDone {
return false, errNetworkNotInitialized
}
return c.cache.stateRootInHeader, nil
}
// GetNativeContractHash returns native contract hash by its name.
func (c *Client) GetNativeContractHash(name string) (util.Uint160, error) {
c.cacheLock.RLock()
hash, ok := c.cache.nativeHashes[name]
c.cacheLock.RUnlock()
if ok {
return hash, nil
}
@ -1009,6 +1039,8 @@ func (c *Client) GetNativeContractHash(name string) (util.Uint160, error) {
if err != nil {
return util.Uint160{}, err
}
c.cacheLock.Lock()
c.cache.nativeHashes[name] = cs.Hash
c.cacheLock.Unlock()
return cs.Hash, nil
}

View file

@ -6,6 +6,7 @@ import (
"encoding/base64"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"math/big"
"net/http"
@ -1704,6 +1705,7 @@ func TestRPCClients(t *testing.T) {
testRPCClient(t, func(ctx context.Context, endpoint string, opts Options) (*Client, error) {
c, err := New(ctx, endpoint, opts)
require.NoError(t, err)
c.getNextRequestID = getTestRequestID
require.NoError(t, c.Init())
return c, nil
})
@ -1712,6 +1714,7 @@ func TestRPCClients(t *testing.T) {
testRPCClient(t, func(ctx context.Context, endpoint string, opts Options) (*Client, error) {
wsc, err := NewWS(ctx, httpURLtoWS(endpoint), opts)
require.NoError(t, err)
wsc.getNextRequestID = getTestRequestID
require.NoError(t, wsc.Init())
return &wsc.Client, nil
})
@ -1731,6 +1734,7 @@ func testRPCClient(t *testing.T, newClient func(context.Context, string, Options
if err != nil {
t.Fatal(err)
}
c.getNextRequestID = getTestRequestID
actual, err := testCase.invoke(c)
if testCase.fails {
@ -1754,14 +1758,14 @@ func testRPCClient(t *testing.T, newClient func(context.Context, string, Options
endpoint := srv.URL
opts := Options{}
for _, testCase := range testBatch {
t.Run(testCase.name, func(t *testing.T) {
c, err := newClient(context.TODO(), endpoint, opts)
if err != nil {
t.Fatal(err)
}
for _, testCase := range testBatch {
t.Run(testCase.name, func(t *testing.T) {
_, err := testCase.invoke(c)
c.getNextRequestID = getTestRequestID
_, err = testCase.invoke(c)
assert.Error(t, err)
})
}
@ -1877,6 +1881,7 @@ func TestCalculateValidUntilBlock(t *testing.T) {
if err != nil {
t.Fatal(err)
}
c.getNextRequestID = getTestRequestID
require.NoError(t, c.Init())
validUntilBlock, err := c.CalculateValidUntilBlock()
@ -1912,9 +1917,11 @@ func TestGetNetwork(t *testing.T) {
if err != nil {
t.Fatal(err)
}
c.getNextRequestID = getTestRequestID
// network was not initialised
require.Equal(t, netmode.Magic(0), c.GetNetwork())
require.Equal(t, false, c.initDone)
_, err = c.GetNetwork()
require.True(t, errors.Is(err, errNetworkNotInitialized))
require.Equal(t, false, c.cache.initDone)
})
t.Run("good", func(t *testing.T) {
@ -1922,8 +1929,11 @@ func TestGetNetwork(t *testing.T) {
if err != nil {
t.Fatal(err)
}
c.getNextRequestID = getTestRequestID
require.NoError(t, c.Init())
require.Equal(t, netmode.UnitTestNet, c.GetNetwork())
m, err := c.GetNetwork()
require.NoError(t, err)
require.Equal(t, netmode.UnitTestNet, m)
})
}
@ -1941,6 +1951,7 @@ func TestUninitedClient(t *testing.T) {
c, err := New(context.TODO(), endpoint, opts)
require.NoError(t, err)
c.getNextRequestID = getTestRequestID
_, err = c.GetBlockByIndex(0)
require.Error(t, err)
@ -1966,3 +1977,7 @@ func newTestNEF(script []byte) nef.File {
ne.Checksum = ne.CalculateChecksum()
return ne
}
func getTestRequestID() uint64 {
return 1
}

View file

@ -4,6 +4,8 @@ import (
"context"
"encoding/json"
"errors"
"strconv"
"sync"
"time"
"github.com/gorilla/websocket"
@ -20,6 +22,8 @@ import (
// servers. It's supposed to be faster than Client because it has persistent
// connection to the server and at the same time is exposes some functionality
// that is only provided via websockets (like event subscription mechanism).
// WSClient is thread-safe and can be used from multiple goroutines to perform
// RPC requests.
type WSClient struct {
Client
// Notifications is a channel that is used to send events received from
@ -32,10 +36,14 @@ type WSClient struct {
ws *websocket.Conn
done chan struct{}
responses chan *response.Raw
requests chan *request.Raw
shutdown chan struct{}
subscriptionsLock sync.RWMutex
subscriptions map[string]bool
respLock sync.RWMutex
respChannels map[uint64]chan *response.Raw
}
// Notification represents server-generated notification for client subscriptions.
@ -73,29 +81,29 @@ const (
// You should call Init method to initialize network magic the client is
// operating on.
func NewWS(ctx context.Context, endpoint string, opts Options) (*WSClient, error) {
cl, err := New(ctx, endpoint, opts)
if err != nil {
return nil, err
}
cl.cli = nil
dialer := websocket.Dialer{HandshakeTimeout: opts.DialTimeout}
ws, _, err := dialer.Dial(endpoint, nil)
if err != nil {
return nil, err
}
wsc := &WSClient{
Client: *cl,
Client: Client{},
Notifications: make(chan Notification),
ws: ws,
shutdown: make(chan struct{}),
done: make(chan struct{}),
responses: make(chan *response.Raw),
respChannels: make(map[uint64]chan *response.Raw),
requests: make(chan *request.Raw),
subscriptions: make(map[string]bool),
}
err = initClient(ctx, &wsc.Client, endpoint, opts)
if err != nil {
return nil, err
}
wsc.Client.cli = nil
go wsc.wsReader()
go wsc.wsWriter()
wsc.requestF = wsc.makeWsRequest
@ -141,7 +149,12 @@ readloop:
var val interface{}
switch event {
case response.BlockEventID:
val = block.New(c.StateRootInHeader())
sr, err := c.StateRootInHeader()
if err != nil {
// Client is not initialised.
break
}
val = block.New(sr)
case response.TransactionEventID:
val = &transaction.Transaction{}
case response.NotificationEventID:
@ -170,14 +183,27 @@ readloop:
resp.JSONRPC = rr.JSONRPC
resp.Error = rr.Error
resp.Result = rr.Result
c.responses <- resp
id, err := strconv.Atoi(string(resp.ID))
if err != nil {
break // Malformed response (invalid response ID).
}
ch := c.getResponseChannel(uint64(id))
if ch == nil {
break // Unknown response (unexpected response ID).
}
ch <- resp
} else {
// Malformed response, neither valid request, nor valid response.
break
}
}
close(c.done)
close(c.responses)
c.respLock.Lock()
for _, ch := range c.respChannels {
close(ch)
}
c.respChannels = nil
c.respLock.Unlock()
close(c.Notifications)
}
@ -212,16 +238,41 @@ func (c *WSClient) wsWriter() {
}
}
func (c *WSClient) registerRespChannel(id uint64, ch chan *response.Raw) {
c.respLock.Lock()
defer c.respLock.Unlock()
c.respChannels[id] = ch
}
func (c *WSClient) unregisterRespChannel(id uint64) {
c.respLock.Lock()
defer c.respLock.Unlock()
if ch, ok := c.respChannels[id]; ok {
delete(c.respChannels, id)
close(ch)
}
}
func (c *WSClient) getResponseChannel(id uint64) chan *response.Raw {
c.respLock.RLock()
defer c.respLock.RUnlock()
return c.respChannels[id]
}
func (c *WSClient) makeWsRequest(r *request.Raw) (*response.Raw, error) {
ch := make(chan *response.Raw)
c.registerRespChannel(r.ID, ch)
select {
case <-c.done:
return nil, errors.New("connection lost")
return nil, errors.New("connection lost before sending the request")
case c.requests <- r:
}
select {
case <-c.done:
return nil, errors.New("connection lost")
case resp := <-c.responses:
return nil, errors.New("connection lost while waiting for the response")
case resp := <-ch:
c.unregisterRespChannel(r.ID)
return resp, nil
}
}
@ -232,6 +283,10 @@ func (c *WSClient) performSubscription(params request.RawParams) (string, error)
if err := c.performRequest("subscribe", params, &resp); err != nil {
return "", err
}
c.subscriptionsLock.Lock()
defer c.subscriptionsLock.Unlock()
c.subscriptions[resp] = true
return resp, nil
}
@ -239,6 +294,9 @@ func (c *WSClient) performSubscription(params request.RawParams) (string, error)
func (c *WSClient) performUnsubscription(id string) error {
var resp bool
c.subscriptionsLock.Lock()
defer c.subscriptionsLock.Unlock()
if !c.subscriptions[id] {
return errors.New("no subscription with this ID")
}
@ -320,11 +378,18 @@ func (c *WSClient) Unsubscribe(id string) error {
// UnsubscribeAll removes all active subscriptions of current client.
func (c *WSClient) UnsubscribeAll() error {
c.subscriptionsLock.Lock()
defer c.subscriptionsLock.Unlock()
for id := range c.subscriptions {
err := c.performUnsubscription(id)
if err != nil {
var resp bool
if err := c.performRequest("unsubscribe", request.NewRawParams(id), &resp); err != nil {
return err
}
if !resp {
return errors.New("unsubscribe method returned false result")
}
delete(c.subscriptions, id)
}
return nil
}

View file

@ -6,7 +6,10 @@ import (
"fmt"
"net/http"
"net/http/httptest"
"sort"
"strconv"
"strings"
"sync"
"testing"
"time"
@ -15,6 +18,7 @@ import (
"github.com/nspcc-dev/neo-go/pkg/rpc/request"
"github.com/nspcc-dev/neo-go/pkg/util"
"github.com/stretchr/testify/require"
"go.uber.org/atomic"
)
func TestWSClientClose(t *testing.T) {
@ -45,6 +49,7 @@ func TestWSClientSubscription(t *testing.T) {
srv := initTestServer(t, `{"jsonrpc": "2.0", "id": 1, "result": "55aaff00"}`)
wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{})
require.NoError(t, err)
wsc.getNextRequestID = getTestRequestID
require.NoError(t, wsc.Init())
id, err := f(wsc)
require.NoError(t, err)
@ -58,6 +63,7 @@ func TestWSClientSubscription(t *testing.T) {
srv := initTestServer(t, `{"jsonrpc": "2.0", "id": 1, "error":{"code":-32602,"message":"Invalid Params"}}`)
wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{})
require.NoError(t, err)
wsc.getNextRequestID = getTestRequestID
require.NoError(t, wsc.Init())
_, err = f(wsc)
require.Error(t, err)
@ -107,6 +113,7 @@ func TestWSClientUnsubscription(t *testing.T) {
srv := initTestServer(t, rc.response)
wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{})
require.NoError(t, err)
wsc.getNextRequestID = getTestRequestID
require.NoError(t, wsc.Init())
rc.code(t, wsc)
})
@ -143,7 +150,9 @@ func TestWSClientEvents(t *testing.T) {
wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{})
require.NoError(t, err)
wsc.network = netmode.UnitTestNet
wsc.getNextRequestID = getTestRequestID
wsc.cache.initDone = true // Our server mock is restricted, so perform initialisation manually.
wsc.cache.network = netmode.UnitTestNet
for range events {
select {
case _, ok = <-wsc.Notifications:
@ -166,6 +175,7 @@ func TestWSExecutionVMStateCheck(t *testing.T) {
srv := initTestServer(t, `{"jsonrpc": "2.0", "id": 1, "result": "55aaff00"}`)
wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{})
require.NoError(t, err)
wsc.getNextRequestID = getTestRequestID
require.NoError(t, wsc.Init())
filter := "NONE"
_, err = wsc.SubscribeForTransactionExecutions(&filter)
@ -315,7 +325,8 @@ func TestWSFilteredSubscriptions(t *testing.T) {
}))
wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{})
require.NoError(t, err)
wsc.network = netmode.UnitTestNet
wsc.getNextRequestID = getTestRequestID
wsc.cache.network = netmode.UnitTestNet
c.clientCode(t, wsc)
wsc.Close()
})
@ -328,6 +339,8 @@ func TestNewWS(t *testing.T) {
t.Run("good", func(t *testing.T) {
c, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{})
require.NoError(t, err)
c.getNextRequestID = getTestRequestID
c.cache.network = netmode.UnitTestNet
require.NoError(t, c.Init())
})
t.Run("bad URL", func(t *testing.T) {
@ -335,3 +348,96 @@ func TestNewWS(t *testing.T) {
require.Error(t, err)
})
}
func TestWSConcurrentAccess(t *testing.T) {
var ids struct {
lock sync.RWMutex
m map[int]struct{}
}
ids.m = make(map[int]struct{})
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if req.URL.Path == "/ws" && req.Method == "GET" {
var upgrader = websocket.Upgrader{}
ws, err := upgrader.Upgrade(w, req, nil)
require.NoError(t, err)
for {
err = ws.SetReadDeadline(time.Now().Add(2 * time.Second))
require.NoError(t, err)
_, p, err := ws.ReadMessage()
if err != nil {
break
}
r := request.NewIn()
err = json.Unmarshal(p, r)
if err != nil {
t.Fatalf("Cannot decode request body: %s", req.Body)
}
i, err := strconv.Atoi(string(r.RawID))
require.NoError(t, err)
ids.lock.Lock()
ids.m[i] = struct{}{}
ids.lock.Unlock()
var response string
// Different responses to catch possible unmarshalling errors connected with invalid IDs distribution.
switch r.Method {
case "getblockcount":
response = fmt.Sprintf(`{"id":%s,"jsonrpc":"2.0","result":123}`, r.RawID)
case "getversion":
response = fmt.Sprintf(`{"id":%s,"jsonrpc":"2.0","result":{"network":42,"tcpport":20332,"wsport":20342,"nonce":2153672787,"useragent":"/NEO-GO:0.73.1-pre-273-ge381358/"}}`, r.RawID)
case "getblockhash":
response = fmt.Sprintf(`{"id":%s,"jsonrpc":"2.0","result":"0x157ca5e5b8cf8f84c9660502a3270b346011612bded1514a6847f877c433a9bb"}`, r.RawID)
}
err = ws.SetWriteDeadline(time.Now().Add(2 * time.Second))
require.NoError(t, err)
err = ws.WriteMessage(1, []byte(response))
if err != nil {
break
}
}
ws.Close()
return
}
}))
t.Cleanup(srv.Close)
wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{})
require.NoError(t, err)
batchCount := 100
completed := atomic.NewInt32(0)
for i := 0; i < batchCount; i++ {
go func() {
_, err := wsc.GetBlockCount()
require.NoError(t, err)
completed.Inc()
}()
go func() {
_, err := wsc.GetBlockHash(123)
require.NoError(t, err)
completed.Inc()
}()
go func() {
_, err := wsc.GetVersion()
require.NoError(t, err)
completed.Inc()
}()
}
require.Eventually(t, func() bool {
return int(completed.Load()) == batchCount*3
}, time.Second, 100*time.Millisecond)
ids.lock.RLock()
require.True(t, len(ids.m) > batchCount)
idsList := make([]int, 0, len(ids.m))
for i := range ids.m {
idsList = append(idsList, i)
}
ids.lock.RUnlock()
sort.Ints(idsList)
require.Equal(t, 1, idsList[0])
require.Less(t, idsList[len(idsList)-1],
batchCount*3+1) // batchCount*requestsPerBatch+1
wsc.Close()
}

View file

@ -32,12 +32,12 @@ func NewRawParams(vals ...interface{}) RawParams {
return p
}
// Raw represents JSON-RPC request.
// Raw represents JSON-RPC request on the Client side.
type Raw struct {
JSONRPC string `json:"jsonrpc"`
Method string `json:"method"`
RawParams []interface{} `json:"params"`
ID int `json:"id"`
ID uint64 `json:"id"`
}
// Request contains standard JSON-RPC 2.0 request and batch of

View file

@ -555,9 +555,11 @@ func TestSignAndPushP2PNotaryRequest(t *testing.T) {
c, err := client.New(context.Background(), httpSrv.URL, client.Options{})
require.NoError(t, err)
acc, err := wallet.NewAccount()
require.NoError(t, err)
t.Run("client wasn't initialized", func(t *testing.T) {
_, err := c.SignAndPushP2PNotaryRequest(nil, nil, 0, 0, 0, nil)
_, err := c.SignAndPushP2PNotaryRequest(transaction.New([]byte{byte(opcode.RET)}, 123), []byte{byte(opcode.RET)}, -1, 0, 100, acc)
require.NotNil(t, err)
})
@ -567,8 +569,6 @@ func TestSignAndPushP2PNotaryRequest(t *testing.T) {
require.NotNil(t, err)
})
acc, err := wallet.NewAccount()
require.NoError(t, err)
t.Run("bad fallback script", func(t *testing.T) {
_, err := c.SignAndPushP2PNotaryRequest(nil, []byte{byte(opcode.ASSERT)}, -1, 0, 0, acc)
require.NotNil(t, err)
@ -649,7 +649,7 @@ func TestCalculateNotaryFee(t *testing.T) {
t.Run("client not initialized", func(t *testing.T) {
_, err := c.CalculateNotaryFee(0)
require.NotNil(t, err)
require.NoError(t, err) // Do not require client initialisation for this.
})
}