Fix possible data race (#136)

- balancer / wif / http.Client could be with data race
- add getter / setter with sync.Mutex
- now http.Client is pointer
- now you can provide your http.Client to rpcClient
This commit is contained in:
Evgeniy Kulikov 2019-02-12 22:03:21 +03:00 committed by fabwa
parent 630919bf7d
commit 845d719698
2 changed files with 79 additions and 34 deletions

View file

@ -8,6 +8,7 @@ import (
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
"sync"
"time" "time"
"github.com/CityOfZion/neo-go/pkg/wallet" "github.com/CityOfZion/neo-go/pkg/wallet"
@ -25,23 +26,26 @@ var (
type Client struct { type Client struct {
// The underlying http client. It's never a good practice to use // The underlying http client. It's never a good practice to use
// the http.DefaultClient, therefore we will role our own. // the http.DefaultClient, therefore we will role our own.
http.Client cliMu *sync.Mutex
endpoint *url.URL cli *http.Client
ctx context.Context endpoint *url.URL
version string ctx context.Context
Wif *wallet.WIF version string
Balancer BalanceGetter wifMu *sync.Mutex
wif *wallet.WIF
balancerMu *sync.Mutex
balancer BalanceGetter
} }
// ClientOptions defines options for the RPC client. // ClientOptions defines options for the RPC client.
// All Values are optional. If any duration is not specified // All Values are optional. If any duration is not specified
// a default of 3 seconds will be used. // a default of 3 seconds will be used.
type ClientOptions struct { type ClientOptions struct {
Cert string Cert string
Key string Key string
CACert string CACert string
DialTimeout time.Duration DialTimeout time.Duration
RequestTimeout time.Duration Client *http.Client
// Version is the version of the client that will be send // Version is the version of the client that will be send
// along with the request body. If no version is specified // along with the request body. If no version is specified
// the default version (currently 2.0) will be used. // the default version (currently 2.0) will be used.
@ -55,20 +59,18 @@ func NewClient(ctx context.Context, endpoint string, opts ClientOptions) (*Clien
return nil, err return nil, err
} }
if opts.DialTimeout == 0 {
opts.DialTimeout = defaultDialTimeout
}
if opts.RequestTimeout == 0 {
opts.RequestTimeout = defaultRequestTimeout
}
if opts.Version == "" { if opts.Version == "" {
opts.Version = defaultClientVersion opts.Version = defaultClientVersion
} }
transport := &http.Transport{ if opts.Client == nil {
DialContext: (&net.Dialer{ opts.Client = &http.Client{
Timeout: opts.DialTimeout, Transport: &http.Transport{
}).DialContext, DialContext: (&net.Dialer{
Timeout: opts.DialTimeout,
}).DialContext,
},
}
} }
// TODO(@antdm): Enable SSL. // TODO(@antdm): Enable SSL.
@ -76,30 +78,73 @@ func NewClient(ctx context.Context, endpoint string, opts ClientOptions) (*Clien
} }
if opts.Client.Timeout == 0 {
opts.Client.Timeout = defaultRequestTimeout
}
return &Client{ return &Client{
Client: http.Client{ ctx: ctx,
Timeout: opts.RequestTimeout, cli: opts.Client,
Transport: transport, cliMu: new(sync.Mutex),
}, balancerMu: new(sync.Mutex),
endpoint: url, wifMu: new(sync.Mutex),
ctx: ctx, endpoint: url,
version: opts.Version, version: opts.Version,
}, nil }, nil
} }
func (c *Client) WIF() wallet.WIF {
c.wifMu.Lock()
defer c.wifMu.Unlock()
return wallet.WIF{
Version: c.wif.Version,
Compressed: c.wif.Compressed,
PrivateKey: c.wif.PrivateKey,
S: c.wif.S,
}
}
// SetWIF decodes given WIF and adds some wallet // SetWIF decodes given WIF and adds some wallet
// data to client. Useful for RPC calls that require an open wallet. // data to client. Useful for RPC calls that require an open wallet.
func (c *Client) SetWIF(wif string) error { func (c *Client) SetWIF(wif string) error {
c.wifMu.Lock()
defer c.wifMu.Unlock()
decodedWif, err := wallet.WIFDecode(wif, 0x00) decodedWif, err := wallet.WIFDecode(wif, 0x00)
if err != nil { if err != nil {
return errors.Wrap(err, "Failed to decode WIF; failed to add WIF to client ") return errors.Wrap(err, "Failed to decode WIF; failed to add WIF to client ")
} }
c.Wif = decodedWif c.wif = decodedWif
return nil return nil
} }
func (c *Client) Balancer() BalanceGetter {
c.balancerMu.Lock()
defer c.balancerMu.Unlock()
return c.balancer
}
func (c *Client) SetBalancer(b BalanceGetter) { func (c *Client) SetBalancer(b BalanceGetter) {
c.Balancer = b c.balancerMu.Lock()
defer c.balancerMu.Unlock()
if b != nil {
c.balancer = b
}
}
func (c *Client) Client() *http.Client {
c.cliMu.Lock()
defer c.cliMu.Unlock()
return c.cli
}
func (c *Client) SetClient(cli *http.Client) {
c.cliMu.Lock()
defer c.cliMu.Unlock()
if cli != nil {
c.cli = cli
}
} }
func (c *Client) performRequest(method string, p params, v interface{}) error { func (c *Client) performRequest(method string, p params, v interface{}) error {
@ -121,7 +166,7 @@ func (c *Client) performRequest(method string, p params, v interface{}) error {
if err != nil { if err != nil {
return err return err
} }
resp, err := c.Do(req) resp, err := c.Client().Do(req)
if err != nil { if err != nil {
return err return err
} }

View file

@ -106,7 +106,7 @@ func (c *Client) SendRawTransaction(rawTX string) (*response, error) {
} }
// SendToAddress sends an amount of specific asset to a given address. // SendToAddress sends an amount of specific asset to a given address.
// This call requires open wallet. (`Wif` key in client struct.) // This call requires open wallet. (`wif` key in client struct.)
// If response.Result is `true` then transaction was formed correctly and was written in blockchain. // If response.Result is `true` then transaction was formed correctly and was written in blockchain.
func (c *Client) SendToAddress(asset util.Uint256, address string, amount util.Fixed8) (*SendToAddressResponse, error) { func (c *Client) SendToAddress(asset util.Uint256, address string, amount util.Fixed8) (*SendToAddressResponse, error) {
var ( var (
@ -118,8 +118,8 @@ func (c *Client) SendToAddress(asset util.Uint256, address string, amount util.F
assetId: asset, assetId: asset,
address: address, address: address,
value: amount, value: amount,
wif: *c.Wif, wif: c.WIF(),
balancer: c.Balancer, balancer: c.Balancer(),
} }
resp *response resp *response
response = &SendToAddressResponse{} response = &SendToAddressResponse{}