Fix possible data race (#136)

- balancer / wif / http.Client could be with data race
- add getter / setter with sync.Mutex
- now http.Client is pointer
- now you can provide your http.Client to rpcClient
This commit is contained in:
Evgeniy Kulikov 2019-02-12 22:03:21 +03:00 committed by fabwa
parent 630919bf7d
commit 845d719698
2 changed files with 79 additions and 34 deletions

View file

@ -8,6 +8,7 @@ import (
"net"
"net/http"
"net/url"
"sync"
"time"
"github.com/CityOfZion/neo-go/pkg/wallet"
@ -25,12 +26,15 @@ var (
type Client struct {
// The underlying http client. It's never a good practice to use
// the http.DefaultClient, therefore we will role our own.
http.Client
cliMu *sync.Mutex
cli *http.Client
endpoint *url.URL
ctx context.Context
version string
Wif *wallet.WIF
Balancer BalanceGetter
wifMu *sync.Mutex
wif *wallet.WIF
balancerMu *sync.Mutex
balancer BalanceGetter
}
// ClientOptions defines options for the RPC client.
@ -41,7 +45,7 @@ type ClientOptions struct {
Key string
CACert string
DialTimeout time.Duration
RequestTimeout time.Duration
Client *http.Client
// Version is the version of the client that will be send
// along with the request body. If no version is specified
// the default version (currently 2.0) will be used.
@ -55,20 +59,18 @@ func NewClient(ctx context.Context, endpoint string, opts ClientOptions) (*Clien
return nil, err
}
if opts.DialTimeout == 0 {
opts.DialTimeout = defaultDialTimeout
}
if opts.RequestTimeout == 0 {
opts.RequestTimeout = defaultRequestTimeout
}
if opts.Version == "" {
opts.Version = defaultClientVersion
}
transport := &http.Transport{
if opts.Client == nil {
opts.Client = &http.Client{
Transport: &http.Transport{
DialContext: (&net.Dialer{
Timeout: opts.DialTimeout,
}).DialContext,
},
}
}
// TODO(@antdm): Enable SSL.
@ -76,30 +78,73 @@ func NewClient(ctx context.Context, endpoint string, opts ClientOptions) (*Clien
}
if opts.Client.Timeout == 0 {
opts.Client.Timeout = defaultRequestTimeout
}
return &Client{
Client: http.Client{
Timeout: opts.RequestTimeout,
Transport: transport,
},
endpoint: url,
ctx: ctx,
cli: opts.Client,
cliMu: new(sync.Mutex),
balancerMu: new(sync.Mutex),
wifMu: new(sync.Mutex),
endpoint: url,
version: opts.Version,
}, nil
}
func (c *Client) WIF() wallet.WIF {
c.wifMu.Lock()
defer c.wifMu.Unlock()
return wallet.WIF{
Version: c.wif.Version,
Compressed: c.wif.Compressed,
PrivateKey: c.wif.PrivateKey,
S: c.wif.S,
}
}
// SetWIF decodes given WIF and adds some wallet
// data to client. Useful for RPC calls that require an open wallet.
func (c *Client) SetWIF(wif string) error {
c.wifMu.Lock()
defer c.wifMu.Unlock()
decodedWif, err := wallet.WIFDecode(wif, 0x00)
if err != nil {
return errors.Wrap(err, "Failed to decode WIF; failed to add WIF to client ")
}
c.Wif = decodedWif
c.wif = decodedWif
return nil
}
func (c *Client) Balancer() BalanceGetter {
c.balancerMu.Lock()
defer c.balancerMu.Unlock()
return c.balancer
}
func (c *Client) SetBalancer(b BalanceGetter) {
c.Balancer = b
c.balancerMu.Lock()
defer c.balancerMu.Unlock()
if b != nil {
c.balancer = b
}
}
func (c *Client) Client() *http.Client {
c.cliMu.Lock()
defer c.cliMu.Unlock()
return c.cli
}
func (c *Client) SetClient(cli *http.Client) {
c.cliMu.Lock()
defer c.cliMu.Unlock()
if cli != nil {
c.cli = cli
}
}
func (c *Client) performRequest(method string, p params, v interface{}) error {
@ -121,7 +166,7 @@ func (c *Client) performRequest(method string, p params, v interface{}) error {
if err != nil {
return err
}
resp, err := c.Do(req)
resp, err := c.Client().Do(req)
if err != nil {
return err
}

View file

@ -106,7 +106,7 @@ func (c *Client) SendRawTransaction(rawTX string) (*response, error) {
}
// SendToAddress sends an amount of specific asset to a given address.
// This call requires open wallet. (`Wif` key in client struct.)
// This call requires open wallet. (`wif` key in client struct.)
// If response.Result is `true` then transaction was formed correctly and was written in blockchain.
func (c *Client) SendToAddress(asset util.Uint256, address string, amount util.Fixed8) (*SendToAddressResponse, error) {
var (
@ -118,8 +118,8 @@ func (c *Client) SendToAddress(asset util.Uint256, address string, amount util.F
assetId: asset,
address: address,
value: amount,
wif: *c.Wif,
balancer: c.Balancer,
wif: c.WIF(),
balancer: c.Balancer(),
}
resp *response
response = &SendToAddressResponse{}