From 845d7196980f059f0ab7f2209d9160cc0b89e980 Mon Sep 17 00:00:00 2001 From: Evgeniy Kulikov Date: Tue, 12 Feb 2019 22:03:21 +0300 Subject: [PATCH] Fix possible data race (#136) - balancer / wif / http.Client could be with data race - add getter / setter with sync.Mutex - now http.Client is pointer - now you can provide your http.Client to rpcClient --- pkg/rpc/client.go | 107 ++++++++++++++++++++++++++++++++-------------- pkg/rpc/rpc.go | 6 +-- 2 files changed, 79 insertions(+), 34 deletions(-) diff --git a/pkg/rpc/client.go b/pkg/rpc/client.go index 2e29bfaa5..2df32a43f 100644 --- a/pkg/rpc/client.go +++ b/pkg/rpc/client.go @@ -8,6 +8,7 @@ import ( "net" "net/http" "net/url" + "sync" "time" "github.com/CityOfZion/neo-go/pkg/wallet" @@ -25,23 +26,26 @@ var ( type Client struct { // The underlying http client. It's never a good practice to use // the http.DefaultClient, therefore we will role our own. - http.Client - endpoint *url.URL - ctx context.Context - version string - Wif *wallet.WIF - Balancer BalanceGetter + cliMu *sync.Mutex + cli *http.Client + endpoint *url.URL + ctx context.Context + version string + wifMu *sync.Mutex + wif *wallet.WIF + balancerMu *sync.Mutex + balancer BalanceGetter } // ClientOptions defines options for the RPC client. // All Values are optional. If any duration is not specified // a default of 3 seconds will be used. type ClientOptions struct { - Cert string - Key string - CACert string - DialTimeout time.Duration - RequestTimeout time.Duration + Cert string + Key string + CACert string + DialTimeout time.Duration + Client *http.Client // Version is the version of the client that will be send // along with the request body. If no version is specified // the default version (currently 2.0) will be used. @@ -55,20 +59,18 @@ func NewClient(ctx context.Context, endpoint string, opts ClientOptions) (*Clien return nil, err } - if opts.DialTimeout == 0 { - opts.DialTimeout = defaultDialTimeout - } - if opts.RequestTimeout == 0 { - opts.RequestTimeout = defaultRequestTimeout - } if opts.Version == "" { opts.Version = defaultClientVersion } - transport := &http.Transport{ - DialContext: (&net.Dialer{ - Timeout: opts.DialTimeout, - }).DialContext, + if opts.Client == nil { + opts.Client = &http.Client{ + Transport: &http.Transport{ + DialContext: (&net.Dialer{ + Timeout: opts.DialTimeout, + }).DialContext, + }, + } } // TODO(@antdm): Enable SSL. @@ -76,30 +78,73 @@ func NewClient(ctx context.Context, endpoint string, opts ClientOptions) (*Clien } + if opts.Client.Timeout == 0 { + opts.Client.Timeout = defaultRequestTimeout + } + return &Client{ - Client: http.Client{ - Timeout: opts.RequestTimeout, - Transport: transport, - }, - endpoint: url, - ctx: ctx, - version: opts.Version, + ctx: ctx, + cli: opts.Client, + cliMu: new(sync.Mutex), + balancerMu: new(sync.Mutex), + wifMu: new(sync.Mutex), + endpoint: url, + version: opts.Version, }, nil } +func (c *Client) WIF() wallet.WIF { + c.wifMu.Lock() + defer c.wifMu.Unlock() + return wallet.WIF{ + Version: c.wif.Version, + Compressed: c.wif.Compressed, + PrivateKey: c.wif.PrivateKey, + S: c.wif.S, + } +} + // SetWIF decodes given WIF and adds some wallet // data to client. Useful for RPC calls that require an open wallet. func (c *Client) SetWIF(wif string) error { + c.wifMu.Lock() + defer c.wifMu.Unlock() decodedWif, err := wallet.WIFDecode(wif, 0x00) if err != nil { return errors.Wrap(err, "Failed to decode WIF; failed to add WIF to client ") } - c.Wif = decodedWif + c.wif = decodedWif return nil } +func (c *Client) Balancer() BalanceGetter { + c.balancerMu.Lock() + defer c.balancerMu.Unlock() + return c.balancer +} + func (c *Client) SetBalancer(b BalanceGetter) { - c.Balancer = b + c.balancerMu.Lock() + defer c.balancerMu.Unlock() + + if b != nil { + c.balancer = b + } +} + +func (c *Client) Client() *http.Client { + c.cliMu.Lock() + defer c.cliMu.Unlock() + return c.cli +} + +func (c *Client) SetClient(cli *http.Client) { + c.cliMu.Lock() + defer c.cliMu.Unlock() + + if cli != nil { + c.cli = cli + } } func (c *Client) performRequest(method string, p params, v interface{}) error { @@ -121,7 +166,7 @@ func (c *Client) performRequest(method string, p params, v interface{}) error { if err != nil { return err } - resp, err := c.Do(req) + resp, err := c.Client().Do(req) if err != nil { return err } diff --git a/pkg/rpc/rpc.go b/pkg/rpc/rpc.go index 8b444f2c3..a90f8be54 100644 --- a/pkg/rpc/rpc.go +++ b/pkg/rpc/rpc.go @@ -106,7 +106,7 @@ func (c *Client) SendRawTransaction(rawTX string) (*response, error) { } // SendToAddress sends an amount of specific asset to a given address. -// This call requires open wallet. (`Wif` key in client struct.) +// This call requires open wallet. (`wif` key in client struct.) // If response.Result is `true` then transaction was formed correctly and was written in blockchain. func (c *Client) SendToAddress(asset util.Uint256, address string, amount util.Fixed8) (*SendToAddressResponse, error) { var ( @@ -118,8 +118,8 @@ func (c *Client) SendToAddress(asset util.Uint256, address string, amount util.F assetId: asset, address: address, value: amount, - wif: *c.Wif, - balancer: c.Balancer, + wif: c.WIF(), + balancer: c.Balancer(), } resp *response response = &SendToAddressResponse{}